diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go index 2a43f95..411a3f9 100644 --- a/internal/transport/httphandler.go +++ b/internal/transport/httphandler.go @@ -2,6 +2,7 @@ package transport import ( "bufio" + "context" "errors" "fmt" "io" @@ -19,8 +20,6 @@ import ( "golang.org/x/crypto/ssh" ) -var openChannelTimeout = 5 * time.Second - type httpHandler struct { domain string sessionRegistry registry.Registry @@ -139,13 +138,17 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { } func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { - channel, err := hh.openForwardedChannel(hw, sshSession) + payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, payload) if err != nil { - log.Printf("Failed to establish channel: %v", err) - sshSession.Forwarder().WriteBadGatewayResponse(hw) + log.Printf("Failed to open forwarded-tcpip channel: %v", err) return } + go ssh.DiscardRequests(reqs) + defer func() { err = channel.Close() if err != nil && !errors.Is(err, io.EOF) { @@ -162,51 +165,6 @@ func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.Requ sshSession.Forwarder().HandleConnection(hw, channel) } -func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) { - payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) - - type channelResult struct { - channel ssh.Channel - reqs <-chan *ssh.Request - err error - } - - resultChan := make(chan channelResult, 1) - - go func() { - channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) - select { - case resultChan <- channelResult{channel, reqs, err}: - default: - hh.cleanupUnusedChannel(channel, reqs) - } - }() - - select { - case result := <-resultChan: - if result.err != nil { - return nil, result.err - } - go ssh.DiscardRequests(result.reqs) - return result.channel, nil - case <-time.After(openChannelTimeout): - go func() { - result := <-resultChan - hh.cleanupUnusedChannel(result.channel, result.reqs) - }() - return nil, errors.New("timeout opening forwarded-tcpip channel") - } -} - -func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) { - if channel != nil { - if err := channel.Close(); err != nil { - log.Printf("Failed to close unused channel: %v", err) - } - go ssh.DiscardRequests(reqs) - } -} - func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) { fingerprintMiddleware := middleware.NewTunnelFingerprint() forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr()) diff --git a/internal/transport/httphandler_test.go b/internal/transport/httphandler_test.go index 984f9b1..b4e592a 100644 --- a/internal/transport/httphandler_test.go +++ b/internal/transport/httphandler_test.go @@ -2,6 +2,7 @@ package transport import ( "bytes" + "context" "fmt" "io" "net" @@ -220,8 +221,8 @@ func (m *MockForwarder) Listener() net.Listener { return args.Get(0).(net.Listener) } -func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { - args := m.Called(payload) +func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(ctx, payload) if args.Get(0) == nil { return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) } @@ -358,12 +359,10 @@ func TestHandler(t *testing.T) { isTLS: true, redirectTLS: false, request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), - expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"), + expected: []byte(""), setupMocks: func(msr *MockSessionRegistry) { mockSession := new(MockSession) mockForwarder := new(MockForwarder) - mockLifecycle := new(MockLifecycle) - mockSSHConn := new(MockSSHConn) msr.On("Get", types.SessionKey{ Id: "test", @@ -371,15 +370,9 @@ func TestHandler(t *testing.T) { }).Return(mockSession, nil) mockSession.On("Forwarder").Return(mockForwarder) - mockSession.On("Lifecycle").Return(mockLifecycle) - mockLifecycle.On("Connection").Return(mockSSHConn) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) - mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed")) - mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) { - w := args.Get(0).(io.Writer) - _, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) - }) + mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed")) }, }, { @@ -391,8 +384,6 @@ func TestHandler(t *testing.T) { setupMocks: func(msr *MockSessionRegistry) { mockSession := new(MockSession) mockForwarder := new(MockForwarder) - mockLifecycle := new(MockLifecycle) - mockSSHConn := new(MockSSHConn) mockSSHChannel := new(MockSSHChannel) msr.On("Get", types.SessionKey{ @@ -401,13 +392,11 @@ func TestHandler(t *testing.T) { }).Return(mockSession, nil) mockSession.On("Forwarder").Return(mockForwarder) - mockSession.On("Lifecycle").Return(mockLifecycle) - mockLifecycle.On("Connection").Return(mockSSHConn) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) reqCh := make(chan *ssh.Request) - mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) mockSSHChannel.On("Close").Return(nil) @@ -427,8 +416,6 @@ func TestHandler(t *testing.T) { setupMocks: func(msr *MockSessionRegistry) { mockSession := new(MockSession) mockForwarder := new(MockForwarder) - mockLifecycle := new(MockLifecycle) - mockSSHConn := new(MockSSHConn) mockSSHChannel := new(MockSSHChannel) msr.On("Get", types.SessionKey{ @@ -437,13 +424,11 @@ func TestHandler(t *testing.T) { }).Return(mockSession, nil) mockSession.On("Forwarder").Return(mockForwarder) - mockSession.On("Lifecycle").Return(mockLifecycle) - mockLifecycle.On("Connection").Return(mockSSHConn) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) reqCh := make(chan *ssh.Request) - mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockSSHChannel.On("Write", mock.Anything).Return(0, nil) mockSSHChannel.On("Close").Return(nil) @@ -524,18 +509,15 @@ func TestHandler(t *testing.T) { setupMocks: func(msr *MockSessionRegistry) { mockSession := new(MockSession) mockForwarder := new(MockForwarder) - mockLifecycle := new(MockLifecycle) - mockSSHConn := new(MockSSHConn) mockSSHChannel := new(MockSSHChannel) msr.On("Get", mock.Anything).Return(mockSession, nil) mockSession.On("Forwarder").Return(mockForwarder) - mockSession.On("Lifecycle").Return(mockLifecycle) - mockLifecycle.On("Connection").Return(mockSSHConn) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + reqCh := make(chan *ssh.Request) - mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockSSHChannel.On("Write", mock.Anything).Return(0, nil) mockSSHChannel.On("Close").Return(nil) @@ -560,19 +542,16 @@ func TestHandler(t *testing.T) { setupMocks: func(msr *MockSessionRegistry) { mockSession := new(MockSession) mockForwarder := new(MockForwarder) + mockSSHChannel := new(MockSSHChannel) + msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool { return k.Id == "test" })).Return(mockSession, nil) mockSession.On("Forwarder").Return(mockForwarder) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) - mockLifecycle := new(MockLifecycle) - mockSession.On("Lifecycle").Return(mockLifecycle) - mockSSHConn := new(MockSSHConn) - mockLifecycle.On("Connection").Return(mockSSHConn) - mockSSHChannel := new(MockSSHChannel) reqCh := make(chan *ssh.Request) - mockSSHConn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockSSHChannel.On("Close").Return(nil) }, setupConn: func() (net.Conn, net.Conn) { @@ -592,18 +571,14 @@ func TestHandler(t *testing.T) { setupMocks: func(msr *MockSessionRegistry) { mockSession := new(MockSession) mockForwarder := new(MockForwarder) - mockLifecycle := new(MockLifecycle) - mockSSHConn := new(MockSSHConn) mockSSHChannel := new(MockSSHChannel) msr.On("Get", mock.Anything).Return(mockSession, nil) mockSession.On("Forwarder").Return(mockForwarder) - mockSession.On("Lifecycle").Return(mockLifecycle) - mockLifecycle.On("Connection").Return(mockSSHConn) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) reqCh := make(chan *ssh.Request) - mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockSSHChannel.On("Write", mock.Anything).Return(0, nil) mockSSHChannel.On("Close").Return(fmt.Errorf("close error")) @@ -619,38 +594,20 @@ func TestHandler(t *testing.T) { isTLS: true, redirectTLS: false, request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), - expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"), + expected: []byte(""), setupMocks: func(msr *MockSessionRegistry) { - oldTimeout := openChannelTimeout - openChannelTimeout = 10 * time.Millisecond - t.Cleanup(func() { - openChannelTimeout = oldTimeout - time.Sleep(100 * time.Millisecond) - }) - mockSession := new(MockSession) mockForwarder := new(MockForwarder) - mockLifecycle := new(MockLifecycle) - mockSSHConn := new(MockSSHConn) - mockSSHChannel := new(MockSSHChannel) msr.On("Get", mock.Anything).Return(mockSession, nil) mockSession.On("Forwarder").Return(mockForwarder) - mockSession.On("Lifecycle").Return(mockLifecycle) - mockLifecycle.On("Connection").Return(mockSSHConn) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) - mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) { - w := args.Get(0).(io.Writer) - _, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) - }) - reqCh := make(chan *ssh.Request) - mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Run(func(args mock.Arguments) { - time.Sleep(50 * time.Millisecond) - }).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) - - mockSSHChannel.On("Close").Return(fmt.Errorf("cleanup close error")) + mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Run(func(args mock.Arguments) { + ctx := args.Get(0).(context.Context) + <-ctx.Done() + }).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded) }, }, } @@ -740,7 +697,7 @@ func TestHandler(t *testing.T) { } else { assert.Equal(t, string(tt.expected), string(response)) } - case <-time.After(2 * time.Second): + case <-time.After(10 * time.Second): if clientConn != nil { t.Fatal("Test timeout - no response received") } diff --git a/internal/transport/tcp.go b/internal/transport/tcp.go index 9ea2354..26c61b7 100644 --- a/internal/transport/tcp.go +++ b/internal/transport/tcp.go @@ -1,11 +1,13 @@ package transport import ( + "context" "errors" "fmt" "io" "log" "net" + "time" "golang.org/x/crypto/ssh" ) @@ -17,7 +19,7 @@ type tcp struct { type Forwarder interface { CreateForwardedTCPIPPayload(origin net.Addr) []byte - OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) + OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) HandleConnection(dst io.ReadWriter, src ssh.Channel) } @@ -54,7 +56,9 @@ func (tt *tcp) handleTcp(conn net.Conn) { } }() payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr()) - channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) return diff --git a/internal/transport/tcp_test.go b/internal/transport/tcp_test.go index d7c3f8b..e8a5790 100644 --- a/internal/transport/tcp_test.go +++ b/internal/transport/tcp_test.go @@ -75,7 +75,7 @@ func TestTCPServer_Serve_Success(t *testing.T) { payload := []byte("test-payload") mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) reqs := make(chan *ssh.Request) - mf.On("OpenForwardedChannel", payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil) + mf.On("OpenForwardedChannel", mock.Anything, payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil) mf.On("HandleConnection", mock.Anything, mock.Anything).Return() go func() { @@ -104,7 +104,7 @@ func TestTCPServer_handleTcp_Success(t *testing.T) { reqs := make(chan *ssh.Request) mockChannel := new(MockSSHChannel) - mf.On("OpenForwardedChannel", payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil) + mf.On("OpenForwardedChannel", mock.Anything, payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil) mf.On("HandleConnection", serverConn, mockChannel).Return() @@ -123,7 +123,7 @@ func TestTCPServer_handleTcp_CloseError(t *testing.T) { payload := []byte("test-payload") mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) - mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) + mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) srv.handleTcp(mc) mc.AssertExpectations(t) @@ -138,7 +138,7 @@ func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) { payload := []byte("test-payload") mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) - mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) + mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) srv.handleTcp(serverConn) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go deleted file mode 100644 index 016445d..0000000 --- a/internal/transport/transport_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package transport - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestTransportInterface(t *testing.T) { - var _ Transport = (*httpServer)(nil) - var _ Transport = (*https)(nil) - var _ Transport = (*tcp)(nil) - - assert.True(t, true) -} diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 43bde3e..1520bb6 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -1,6 +1,7 @@ package forwarder import ( + "context" "errors" "fmt" "io" @@ -8,7 +9,6 @@ import ( "net" "strconv" "sync" - "time" "tunnel_pls/internal/config" "tunnel_pls/session/slug" "tunnel_pls/types" @@ -25,7 +25,7 @@ type Forwarder interface { ForwardedPort() uint16 HandleConnection(dst io.ReadWriter, src ssh.Channel) CreateForwardedTCPIPPayload(origin net.Addr) []byte - OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) + OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) WriteBadGatewayResponse(dst io.Writer) Close() error } @@ -60,7 +60,7 @@ func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, return io.CopyBuffer(dst, src, buf) } -func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { +func (f *forwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { type channelResult struct { channel ssh.Channel reqs <-chan *ssh.Request @@ -72,7 +72,7 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) select { case resultChan <- channelResult{channel, reqs, err}: - default: + case <-ctx.Done(): if channel != nil { err = channel.Close() if err != nil { @@ -87,8 +87,8 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s select { case result := <-resultChan: return result.channel, result.reqs, result.err - case <-time.After(5 * time.Second): - return nil, nil, errors.New("timeout opening forwarded-tcpip channel") + case <-ctx.Done(): + return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err()) } }