From f1d20905d0f94fedc7d5fb200f6176e3a516e2f2 Mon Sep 17 00:00:00 2001 From: bagas Date: Sun, 25 Jan 2026 20:49:12 +0700 Subject: [PATCH] refactor(forwarder): remove CreateForwardedTCPIPPayload method - OpenForwardedChannel now privately calls CreateForwardedTCPIPPayload - Removed an unused function --- internal/transport/httphandler.go | 3 +-- internal/transport/httphandler_test.go | 25 ++++++++----------------- internal/transport/tcp.go | 6 ++---- internal/transport/tcp_test.go | 17 ++++------------- session/forwarder/forwarder.go | 19 +++++-------------- session/lifecycle/lifecycle_test.go | 4 ---- 6 files changed, 20 insertions(+), 54 deletions(-) diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go index 7e2135a..67aa6fb 100644 --- a/internal/transport/httphandler.go +++ b/internal/transport/httphandler.go @@ -152,10 +152,9 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { } func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { - payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, payload) + channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, hw.RemoteAddr()) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) return diff --git a/internal/transport/httphandler_test.go b/internal/transport/httphandler_test.go index b6df77f..de1dd24 100644 --- a/internal/transport/httphandler_test.go +++ b/internal/transport/httphandler_test.go @@ -176,15 +176,6 @@ type MockForwarder struct { mock.Mock } -func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { - args := m.Called(origin) - return args.Get(0).([]byte) -} - -func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) { - m.Called(dst) -} - func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { m.Called(dst, src) } @@ -221,8 +212,8 @@ func (m *MockForwarder) Listener() net.Listener { return args.Get(0).(net.Listener) } -func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { - args := m.Called(ctx, payload) +func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(ctx, origin) if args.Get(0) == nil { return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) } @@ -413,7 +404,7 @@ func TestHandler(t *testing.T) { mockSession.On("Forwarder").Return(mockForwarder) mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) - mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed")) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed")) }, }, { @@ -437,7 +428,7 @@ func TestHandler(t *testing.T) { mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) reqCh := make(chan *ssh.Request) - mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) mockSSHChannel.On("Close").Return(nil) @@ -469,7 +460,7 @@ func TestHandler(t *testing.T) { mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) reqCh := make(chan *ssh.Request) - mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockSSHChannel.On("Write", mock.Anything).Return(0, nil) mockSSHChannel.On("Close").Return(nil) @@ -577,7 +568,7 @@ func TestHandler(t *testing.T) { mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) reqCh := make(chan *ssh.Request) - mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockSSHChannel.On("Write", mock.Anything).Return(0, nil) mockSSHChannel.On("Close").Return(nil) @@ -640,7 +631,7 @@ func TestHandler(t *testing.T) { mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) reqCh := make(chan *ssh.Request) - mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) mockSSHChannel.On("Write", mock.Anything).Return(0, nil) mockSSHChannel.On("Close").Return(fmt.Errorf("close error")) @@ -666,7 +657,7 @@ func TestHandler(t *testing.T) { mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) - mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Run(func(args mock.Arguments) { + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { ctx := args.Get(0).(context.Context) <-ctx.Done() }).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded) diff --git a/internal/transport/tcp.go b/internal/transport/tcp.go index 26c61b7..34ea3c2 100644 --- a/internal/transport/tcp.go +++ b/internal/transport/tcp.go @@ -18,8 +18,7 @@ type tcp struct { } type Forwarder interface { - CreateForwardedTCPIPPayload(origin net.Addr) []byte - OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) + OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) HandleConnection(dst io.ReadWriter, src ssh.Channel) } @@ -55,10 +54,9 @@ func (tt *tcp) handleTcp(conn net.Conn) { log.Printf("Failed to close connection: %v", err) } }() - payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr()) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, payload) + channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, conn.RemoteAddr()) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) return diff --git a/internal/transport/tcp_test.go b/internal/transport/tcp_test.go index e8a5790..409e6f1 100644 --- a/internal/transport/tcp_test.go +++ b/internal/transport/tcp_test.go @@ -72,10 +72,8 @@ func TestTCPServer_Serve_Success(t *testing.T) { assert.NoError(t, err) port := listener.Addr().(*net.TCPAddr).Port - payload := []byte("test-payload") - mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) reqs := make(chan *ssh.Request) - mf.On("OpenForwardedChannel", mock.Anything, payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil) + mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil) mf.On("HandleConnection", mock.Anything, mock.Anything).Return() go func() { @@ -99,12 +97,9 @@ func TestTCPServer_handleTcp_Success(t *testing.T) { serverConn, clientConn := net.Pipe() defer clientConn.Close() - payload := []byte("test-payload") - mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) - reqs := make(chan *ssh.Request) mockChannel := new(MockSSHChannel) - mf.On("OpenForwardedChannel", mock.Anything, payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil) + mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil) mf.On("HandleConnection", serverConn, mockChannel).Return() @@ -121,9 +116,7 @@ func TestTCPServer_handleTcp_CloseError(t *testing.T) { mc.On("Close").Return(errors.New("close error")) mc.On("RemoteAddr").Return(&net.TCPAddr{}) - payload := []byte("test-payload") - mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) - mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) + mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) srv.handleTcp(mc) mc.AssertExpectations(t) @@ -136,9 +129,7 @@ func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) { serverConn, clientConn := net.Pipe() defer clientConn.Close() - payload := []byte("test-payload") - mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) - mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) + mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) srv.handleTcp(serverConn) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 1520bb6..629fffd 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -24,9 +24,7 @@ type Forwarder interface { TunnelType() types.TunnelType ForwardedPort() uint16 HandleConnection(dst io.ReadWriter, src ssh.Channel) - CreateForwardedTCPIPPayload(origin net.Addr) []byte - OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) - WriteBadGatewayResponse(dst io.Writer) + OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) Close() error } type forwarder struct { @@ -60,7 +58,8 @@ func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, return io.CopyBuffer(dst, src, buf) } -func (f *forwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { +func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) { + payload := createForwardedTCPIPPayload(origin, f.forwardedPort) type channelResult struct { channel ssh.Channel reqs <-chan *ssh.Request @@ -171,14 +170,6 @@ func (f *forwarder) Listener() net.Listener { return f.listener } -func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) { - _, err := dst.Write(types.BadGatewayResponse) - if err != nil { - log.Printf("failed to write Bad Gateway response: %v", err) - return - } -} - func (f *forwarder) Close() error { if f.Listener() != nil { return f.listener.Close() @@ -186,7 +177,7 @@ func (f *forwarder) Close() error { return nil } -func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { +func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte { host, portStr, _ := net.SplitHostPort(origin.String()) port, _ := strconv.Atoi(portStr) @@ -197,7 +188,7 @@ func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { OriginPort uint32 }{ DestAddr: "localhost", - DestPort: uint32(f.ForwardedPort()), + DestPort: uint32(destPort), OriginAddr: host, OriginPort: uint32(port), } diff --git a/session/lifecycle/lifecycle_test.go b/session/lifecycle/lifecycle_test.go index 5a8deb0..73333e5 100644 --- a/session/lifecycle/lifecycle_test.go +++ b/session/lifecycle/lifecycle_test.go @@ -29,10 +29,6 @@ func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { return args.Get(0).([]byte) } -func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) { - m.Called(dst) -} - func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { m.Called(dst, src) }