refactor: remove duplicate channel management helpers from HTTP handler
SonarQube Scan / SonarQube Trigger (push) Successful in 2m12s
SonarQube Scan / SonarQube Trigger (push) Successful in 2m12s
This commit is contained in:
@@ -2,6 +2,7 @@ package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -220,8 +221,8 @@ func (m *MockForwarder) Listener() net.Listener {
|
||||
return args.Get(0).(net.Listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(payload)
|
||||
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(ctx, payload)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
@@ -358,12 +359,10 @@ func TestHandler(t *testing.T) {
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
@@ -371,15 +370,9 @@ func TestHandler(t *testing.T) {
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
||||
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.Writer)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
|
||||
})
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -391,8 +384,6 @@ func TestHandler(t *testing.T) {
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
@@ -401,13 +392,11 @@ func TestHandler(t *testing.T) {
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
@@ -427,8 +416,6 @@ func TestHandler(t *testing.T) {
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
@@ -437,13 +424,11 @@ func TestHandler(t *testing.T) {
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
@@ -524,18 +509,15 @@ func TestHandler(t *testing.T) {
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
@@ -560,19 +542,16 @@ func TestHandler(t *testing.T) {
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
|
||||
return k.Id == "test"
|
||||
})).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
},
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
@@ -592,18 +571,14 @@ func TestHandler(t *testing.T) {
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
||||
@@ -619,38 +594,20 @@ func TestHandler(t *testing.T) {
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
oldTimeout := openChannelTimeout
|
||||
openChannelTimeout = 10 * time.Millisecond
|
||||
t.Cleanup(func() {
|
||||
openChannelTimeout = oldTimeout
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.Writer)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
|
||||
})
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Run(func(args mock.Arguments) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Close").Return(fmt.Errorf("cleanup close error"))
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Run(func(args mock.Arguments) {
|
||||
ctx := args.Get(0).(context.Context)
|
||||
<-ctx.Done()
|
||||
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -740,7 +697,7 @@ func TestHandler(t *testing.T) {
|
||||
} else {
|
||||
assert.Equal(t, string(tt.expected), string(response))
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
case <-time.After(10 * time.Second):
|
||||
if clientConn != nil {
|
||||
t.Fatal("Test timeout - no response received")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user