revert-54069ad305 #11

Closed
bagas wants to merge 217 commits from revert-54069ad305 into main
6 changed files with 43 additions and 139 deletions
Showing only changes of commit 9785a97973 - Show all commits
+8 -50
View File
@@ -2,6 +2,7 @@ package transport
import ( import (
"bufio" "bufio"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -19,8 +20,6 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
var openChannelTimeout = 5 * time.Second
type httpHandler struct { type httpHandler struct {
domain string domain string
sessionRegistry registry.Registry sessionRegistry registry.Registry
@@ -139,13 +138,17 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
} }
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession) payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, payload)
if err != nil { if err != nil {
log.Printf("Failed to establish channel: %v", err) log.Printf("Failed to open forwarded-tcpip channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return return
} }
go ssh.DiscardRequests(reqs)
defer func() { defer func() {
err = channel.Close() err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
@@ -162,51 +165,6 @@ func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.Requ
sshSession.Forwarder().HandleConnection(hw, channel) sshSession.Forwarder().HandleConnection(hw, channel)
} }
func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
hh.cleanupUnusedChannel(channel, reqs)
}
}()
select {
case result := <-resultChan:
if result.err != nil {
return nil, result.err
}
go ssh.DiscardRequests(result.reqs)
return result.channel, nil
case <-time.After(openChannelTimeout):
go func() {
result := <-resultChan
hh.cleanupUnusedChannel(result.channel, result.reqs)
}()
return nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
if channel != nil {
if err := channel.Close(); err != nil {
log.Printf("Failed to close unused channel: %v", err)
}
go ssh.DiscardRequests(reqs)
}
}
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) { func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
fingerprintMiddleware := middleware.NewTunnelFingerprint() fingerprintMiddleware := middleware.NewTunnelFingerprint()
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr()) forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
+19 -62
View File
@@ -2,6 +2,7 @@ package transport
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
@@ -220,8 +221,8 @@ func (m *MockForwarder) Listener() net.Listener {
return args.Get(0).(net.Listener) return args.Get(0).(net.Listener)
} }
func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
args := m.Called(payload) args := m.Called(ctx, payload)
if args.Get(0) == nil { if args.Get(0) == nil {
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
} }
@@ -358,12 +359,10 @@ func TestHandler(t *testing.T) {
isTLS: true, isTLS: true,
redirectTLS: false, redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"), expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) { setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession) mockSession := new(MockSession)
mockForwarder := new(MockForwarder) mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
msr.On("Get", types.SessionKey{ msr.On("Get", types.SessionKey{
Id: "test", Id: "test",
@@ -371,15 +370,9 @@ func TestHandler(t *testing.T) {
}).Return(mockSession, nil) }).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder) mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed")) mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
w := args.Get(0).(io.Writer)
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
})
}, },
}, },
{ {
@@ -391,8 +384,6 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) { setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession) mockSession := new(MockSession)
mockForwarder := new(MockForwarder) mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel) mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{ msr.On("Get", types.SessionKey{
@@ -401,13 +392,11 @@ func TestHandler(t *testing.T) {
}).Return(mockSession, nil) }).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder) mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request) reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mockSSHChannel.On("Close").Return(nil) mockSSHChannel.On("Close").Return(nil)
@@ -427,8 +416,6 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) { setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession) mockSession := new(MockSession)
mockForwarder := new(MockForwarder) mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel) mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{ msr.On("Get", types.SessionKey{
@@ -437,13 +424,11 @@ func TestHandler(t *testing.T) {
}).Return(mockSession, nil) }).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder) mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request) reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil) mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil) mockSSHChannel.On("Close").Return(nil)
@@ -524,18 +509,15 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) { setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession) mockSession := new(MockSession)
mockForwarder := new(MockForwarder) mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel) mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil) msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder) mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request) reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil) mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil) mockSSHChannel.On("Close").Return(nil)
@@ -560,19 +542,16 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) { setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession) mockSession := new(MockSession)
mockForwarder := new(MockForwarder) mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool { msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
return k.Id == "test" return k.Id == "test"
})).Return(mockSession, nil) })).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder) mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockLifecycle := new(MockLifecycle)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockSSHConn := new(MockSSHConn)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockSSHChannel := new(MockSSHChannel)
reqCh := make(chan *ssh.Request) reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Close").Return(nil) mockSSHChannel.On("Close").Return(nil)
}, },
setupConn: func() (net.Conn, net.Conn) { setupConn: func() (net.Conn, net.Conn) {
@@ -592,18 +571,14 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) { setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession) mockSession := new(MockSession)
mockForwarder := new(MockForwarder) mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel) mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil) msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder) mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request) reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil) mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(fmt.Errorf("close error")) mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
@@ -619,38 +594,20 @@ func TestHandler(t *testing.T) {
isTLS: true, isTLS: true,
redirectTLS: false, redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"), expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) { setupMocks: func(msr *MockSessionRegistry) {
oldTimeout := openChannelTimeout
openChannelTimeout = 10 * time.Millisecond
t.Cleanup(func() {
openChannelTimeout = oldTimeout
time.Sleep(100 * time.Millisecond)
})
mockSession := new(MockSession) mockSession := new(MockSession)
mockForwarder := new(MockForwarder) mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil) msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder) mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
w := args.Get(0).(io.Writer)
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
})
reqCh := make(chan *ssh.Request) mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Run(func(args mock.Arguments) {
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Run(func(args mock.Arguments) { ctx := args.Get(0).(context.Context)
time.Sleep(50 * time.Millisecond) <-ctx.Done()
}).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) }).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
mockSSHChannel.On("Close").Return(fmt.Errorf("cleanup close error"))
}, },
}, },
} }
@@ -740,7 +697,7 @@ func TestHandler(t *testing.T) {
} else { } else {
assert.Equal(t, string(tt.expected), string(response)) assert.Equal(t, string(tt.expected), string(response))
} }
case <-time.After(2 * time.Second): case <-time.After(10 * time.Second):
if clientConn != nil { if clientConn != nil {
t.Fatal("Test timeout - no response received") t.Fatal("Test timeout - no response received")
} }
+6 -2
View File
@@ -1,11 +1,13 @@
package transport package transport
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -17,7 +19,7 @@ type tcp struct {
type Forwarder interface { type Forwarder interface {
CreateForwardedTCPIPPayload(origin net.Addr) []byte CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
HandleConnection(dst io.ReadWriter, src ssh.Channel) HandleConnection(dst io.ReadWriter, src ssh.Channel)
} }
@@ -54,7 +56,9 @@ func (tt *tcp) handleTcp(conn net.Conn) {
} }
}() }()
payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr()) payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, payload)
if err != nil { if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err) log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return return
+4 -4
View File
@@ -75,7 +75,7 @@ func TestTCPServer_Serve_Success(t *testing.T) {
payload := []byte("test-payload") payload := []byte("test-payload")
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
reqs := make(chan *ssh.Request) reqs := make(chan *ssh.Request)
mf.On("OpenForwardedChannel", payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil) mf.On("OpenForwardedChannel", mock.Anything, payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
mf.On("HandleConnection", mock.Anything, mock.Anything).Return() mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
go func() { go func() {
@@ -104,7 +104,7 @@ func TestTCPServer_handleTcp_Success(t *testing.T) {
reqs := make(chan *ssh.Request) reqs := make(chan *ssh.Request)
mockChannel := new(MockSSHChannel) mockChannel := new(MockSSHChannel)
mf.On("OpenForwardedChannel", payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil) mf.On("OpenForwardedChannel", mock.Anything, payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
mf.On("HandleConnection", serverConn, mockChannel).Return() mf.On("HandleConnection", serverConn, mockChannel).Return()
@@ -123,7 +123,7 @@ func TestTCPServer_handleTcp_CloseError(t *testing.T) {
payload := []byte("test-payload") payload := []byte("test-payload")
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
srv.handleTcp(mc) srv.handleTcp(mc)
mc.AssertExpectations(t) mc.AssertExpectations(t)
@@ -138,7 +138,7 @@ func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
payload := []byte("test-payload") payload := []byte("test-payload")
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
srv.handleTcp(serverConn) srv.handleTcp(serverConn)
-15
View File
@@ -1,15 +0,0 @@
package transport
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestTransportInterface(t *testing.T) {
var _ Transport = (*httpServer)(nil)
var _ Transport = (*https)(nil)
var _ Transport = (*tcp)(nil)
assert.True(t, true)
}
+6 -6
View File
@@ -1,6 +1,7 @@
package forwarder package forwarder
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -8,7 +9,6 @@ import (
"net" "net"
"strconv" "strconv"
"sync" "sync"
"time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
@@ -25,7 +25,7 @@ type Forwarder interface {
ForwardedPort() uint16 ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel) HandleConnection(dst io.ReadWriter, src ssh.Channel)
CreateForwardedTCPIPPayload(origin net.Addr) []byte CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
WriteBadGatewayResponse(dst io.Writer) WriteBadGatewayResponse(dst io.Writer)
Close() error Close() error
} }
@@ -60,7 +60,7 @@ func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64,
return io.CopyBuffer(dst, src, buf) return io.CopyBuffer(dst, src, buf)
} }
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { func (f *forwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
type channelResult struct { type channelResult struct {
channel ssh.Channel channel ssh.Channel
reqs <-chan *ssh.Request reqs <-chan *ssh.Request
@@ -72,7 +72,7 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
select { select {
case resultChan <- channelResult{channel, reqs, err}: case resultChan <- channelResult{channel, reqs, err}:
default: case <-ctx.Done():
if channel != nil { if channel != nil {
err = channel.Close() err = channel.Close()
if err != nil { if err != nil {
@@ -87,8 +87,8 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
select { select {
case result := <-resultChan: case result := <-resultChan:
return result.channel, result.reqs, result.err return result.channel, result.reqs, result.err
case <-time.After(5 * time.Second): case <-ctx.Done():
return nil, nil, errors.New("timeout opening forwarded-tcpip channel") return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err())
} }
} }