refactor: remove duplicate channel management helpers from HTTP handler
SonarQube Scan / SonarQube Trigger (push) Successful in 2m12s

This commit is contained in:
2026-01-25 13:47:57 +07:00
parent 2b488a5ab5
commit 8b44e4db4e
6 changed files with 43 additions and 139 deletions
+19 -62
View File
@@ -2,6 +2,7 @@ package transport
import (
"bytes"
"context"
"fmt"
"io"
"net"
@@ -220,8 +221,8 @@ func (m *MockForwarder) Listener() net.Listener {
return args.Get(0).(net.Listener)
}
func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
args := m.Called(payload)
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
args := m.Called(ctx, payload)
if args.Get(0) == nil {
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
@@ -358,12 +359,10 @@ func TestHandler(t *testing.T) {
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
msr.On("Get", types.SessionKey{
Id: "test",
@@ -371,15 +370,9 @@ func TestHandler(t *testing.T) {
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
w := args.Get(0).(io.Writer)
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
})
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
},
},
{
@@ -391,8 +384,6 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{
@@ -401,13 +392,11 @@ func TestHandler(t *testing.T) {
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mockSSHChannel.On("Close").Return(nil)
@@ -427,8 +416,6 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{
@@ -437,13 +424,11 @@ func TestHandler(t *testing.T) {
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil)
@@ -524,18 +509,15 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil)
@@ -560,19 +542,16 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
return k.Id == "test"
})).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockLifecycle := new(MockLifecycle)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockSSHConn := new(MockSSHConn)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockSSHChannel := new(MockSSHChannel)
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Close").Return(nil)
},
setupConn: func() (net.Conn, net.Conn) {
@@ -592,18 +571,14 @@ func TestHandler(t *testing.T) {
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
@@ -619,38 +594,20 @@ func TestHandler(t *testing.T) {
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
oldTimeout := openChannelTimeout
openChannelTimeout = 10 * time.Millisecond
t.Cleanup(func() {
openChannelTimeout = oldTimeout
time.Sleep(100 * time.Millisecond)
})
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
w := args.Get(0).(io.Writer)
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
})
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Run(func(args mock.Arguments) {
time.Sleep(50 * time.Millisecond)
}).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Close").Return(fmt.Errorf("cleanup close error"))
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Run(func(args mock.Arguments) {
ctx := args.Get(0).(context.Context)
<-ctx.Done()
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
},
},
}
@@ -740,7 +697,7 @@ func TestHandler(t *testing.T) {
} else {
assert.Equal(t, string(tt.expected), string(response))
}
case <-time.After(2 * time.Second):
case <-time.After(10 * time.Second):
if clientConn != nil {
t.Fatal("Test timeout - no response received")
}