refactor(forwarder): remove CreateForwardedTCPIPPayload method
- OpenForwardedChannel now privately calls CreateForwardedTCPIPPayload - Removed an unused function
This commit is contained in:
@@ -152,10 +152,9 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
||||||
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, payload)
|
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, hw.RemoteAddr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -176,15 +176,6 @@ type MockForwarder struct {
|
|||||||
mock.Mock
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
|
||||||
args := m.Called(origin)
|
|
||||||
return args.Get(0).([]byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) {
|
|
||||||
m.Called(dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||||
m.Called(dst, src)
|
m.Called(dst, src)
|
||||||
}
|
}
|
||||||
@@ -221,8 +212,8 @@ func (m *MockForwarder) Listener() net.Listener {
|
|||||||
return args.Get(0).(net.Listener)
|
return args.Get(0).(net.Listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
args := m.Called(ctx, payload)
|
args := m.Called(ctx, origin)
|
||||||
if args.Get(0) == nil {
|
if args.Get(0) == nil {
|
||||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||||
}
|
}
|
||||||
@@ -413,7 +404,7 @@ func TestHandler(t *testing.T) {
|
|||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
mockSession.On("Forwarder").Return(mockForwarder)
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -437,7 +428,7 @@ func TestHandler(t *testing.T) {
|
|||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
reqCh := make(chan *ssh.Request)
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
mockSSHChannel.On("Close").Return(nil)
|
||||||
@@ -469,7 +460,7 @@ func TestHandler(t *testing.T) {
|
|||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
reqCh := make(chan *ssh.Request)
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
mockSSHChannel.On("Close").Return(nil)
|
||||||
@@ -577,7 +568,7 @@ func TestHandler(t *testing.T) {
|
|||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
reqCh := make(chan *ssh.Request)
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
mockSSHChannel.On("Close").Return(nil)
|
||||||
@@ -640,7 +631,7 @@ func TestHandler(t *testing.T) {
|
|||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
reqCh := make(chan *ssh.Request)
|
reqCh := make(chan *ssh.Request)
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||||
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
||||||
@@ -666,7 +657,7 @@ func TestHandler(t *testing.T) {
|
|||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
|
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Run(func(args mock.Arguments) {
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||||
ctx := args.Get(0).(context.Context)
|
ctx := args.Get(0).(context.Context)
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
|
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ type tcp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Forwarder interface {
|
type Forwarder interface {
|
||||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
|
||||||
OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
|
||||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,10 +54,9 @@ func (tt *tcp) handleTcp(conn net.Conn) {
|
|||||||
log.Printf("Failed to close connection: %v", err)
|
log.Printf("Failed to close connection: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, payload)
|
channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, conn.RemoteAddr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -72,10 +72,8 @@ func TestTCPServer_Serve_Success(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
port := listener.Addr().(*net.TCPAddr).Port
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
payload := []byte("test-payload")
|
|
||||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
|
||||||
reqs := make(chan *ssh.Request)
|
reqs := make(chan *ssh.Request)
|
||||||
mf.On("OpenForwardedChannel", mock.Anything, payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
|
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
|
||||||
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
|
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -99,12 +97,9 @@ func TestTCPServer_handleTcp_Success(t *testing.T) {
|
|||||||
serverConn, clientConn := net.Pipe()
|
serverConn, clientConn := net.Pipe()
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
payload := []byte("test-payload")
|
|
||||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
|
||||||
|
|
||||||
reqs := make(chan *ssh.Request)
|
reqs := make(chan *ssh.Request)
|
||||||
mockChannel := new(MockSSHChannel)
|
mockChannel := new(MockSSHChannel)
|
||||||
mf.On("OpenForwardedChannel", mock.Anything, payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
|
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
|
||||||
|
|
||||||
mf.On("HandleConnection", serverConn, mockChannel).Return()
|
mf.On("HandleConnection", serverConn, mockChannel).Return()
|
||||||
|
|
||||||
@@ -121,9 +116,7 @@ func TestTCPServer_handleTcp_CloseError(t *testing.T) {
|
|||||||
mc.On("Close").Return(errors.New("close error"))
|
mc.On("Close").Return(errors.New("close error"))
|
||||||
mc.On("RemoteAddr").Return(&net.TCPAddr{})
|
mc.On("RemoteAddr").Return(&net.TCPAddr{})
|
||||||
|
|
||||||
payload := []byte("test-payload")
|
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
|
||||||
mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
|
||||||
|
|
||||||
srv.handleTcp(mc)
|
srv.handleTcp(mc)
|
||||||
mc.AssertExpectations(t)
|
mc.AssertExpectations(t)
|
||||||
@@ -136,9 +129,7 @@ func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
|
|||||||
serverConn, clientConn := net.Pipe()
|
serverConn, clientConn := net.Pipe()
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
payload := []byte("test-payload")
|
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
|
||||||
mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
|
||||||
|
|
||||||
srv.handleTcp(serverConn)
|
srv.handleTcp(serverConn)
|
||||||
|
|
||||||
|
|||||||
@@ -24,9 +24,7 @@ type Forwarder interface {
|
|||||||
TunnelType() types.TunnelType
|
TunnelType() types.TunnelType
|
||||||
ForwardedPort() uint16
|
ForwardedPort() uint16
|
||||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
|
||||||
OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
|
||||||
WriteBadGatewayResponse(dst io.Writer)
|
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
type forwarder struct {
|
type forwarder struct {
|
||||||
@@ -60,7 +58,8 @@ func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64,
|
|||||||
return io.CopyBuffer(dst, src, buf)
|
return io.CopyBuffer(dst, src, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
|
payload := createForwardedTCPIPPayload(origin, f.forwardedPort)
|
||||||
type channelResult struct {
|
type channelResult struct {
|
||||||
channel ssh.Channel
|
channel ssh.Channel
|
||||||
reqs <-chan *ssh.Request
|
reqs <-chan *ssh.Request
|
||||||
@@ -171,14 +170,6 @@ func (f *forwarder) Listener() net.Listener {
|
|||||||
return f.listener
|
return f.listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
|
||||||
_, err := dst.Write(types.BadGatewayResponse)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to write Bad Gateway response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *forwarder) Close() error {
|
func (f *forwarder) Close() error {
|
||||||
if f.Listener() != nil {
|
if f.Listener() != nil {
|
||||||
return f.listener.Close()
|
return f.listener.Close()
|
||||||
@@ -186,7 +177,7 @@ func (f *forwarder) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte {
|
||||||
host, portStr, _ := net.SplitHostPort(origin.String())
|
host, portStr, _ := net.SplitHostPort(origin.String())
|
||||||
port, _ := strconv.Atoi(portStr)
|
port, _ := strconv.Atoi(portStr)
|
||||||
|
|
||||||
@@ -197,7 +188,7 @@ func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
|||||||
OriginPort uint32
|
OriginPort uint32
|
||||||
}{
|
}{
|
||||||
DestAddr: "localhost",
|
DestAddr: "localhost",
|
||||||
DestPort: uint32(f.ForwardedPort()),
|
DestPort: uint32(destPort),
|
||||||
OriginAddr: host,
|
OriginAddr: host,
|
||||||
OriginPort: uint32(port),
|
OriginPort: uint32(port),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,10 +29,6 @@ func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
|||||||
return args.Get(0).([]byte)
|
return args.Get(0).([]byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) {
|
|
||||||
m.Called(dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||||
m.Called(dst, src)
|
m.Called(dst, src)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user