refactor: remove duplicate channel management helpers from HTTP handler
SonarQube Scan / SonarQube Trigger (push) Successful in 2m12s
SonarQube Scan / SonarQube Trigger (push) Successful in 2m12s
This commit is contained in:
@@ -2,6 +2,7 @@ package transport
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -19,8 +20,6 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var openChannelTimeout = 5 * time.Second
|
|
||||||
|
|
||||||
type httpHandler struct {
|
type httpHandler struct {
|
||||||
domain string
|
domain string
|
||||||
sessionRegistry registry.Registry
|
sessionRegistry registry.Registry
|
||||||
@@ -139,13 +138,17 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
||||||
channel, err := hh.openForwardedChannel(hw, sshSession)
|
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to establish channel: %v", err)
|
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||||
sshSession.Forwarder().WriteBadGatewayResponse(hw)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = channel.Close()
|
err = channel.Close()
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
@@ -162,51 +165,6 @@ func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.Requ
|
|||||||
sshSession.Forwarder().HandleConnection(hw, channel)
|
sshSession.Forwarder().HandleConnection(hw, channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
|
|
||||||
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
|
|
||||||
|
|
||||||
type channelResult struct {
|
|
||||||
channel ssh.Channel
|
|
||||||
reqs <-chan *ssh.Request
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
resultChan := make(chan channelResult, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
|
|
||||||
select {
|
|
||||||
case resultChan <- channelResult{channel, reqs, err}:
|
|
||||||
default:
|
|
||||||
hh.cleanupUnusedChannel(channel, reqs)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case result := <-resultChan:
|
|
||||||
if result.err != nil {
|
|
||||||
return nil, result.err
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(result.reqs)
|
|
||||||
return result.channel, nil
|
|
||||||
case <-time.After(openChannelTimeout):
|
|
||||||
go func() {
|
|
||||||
result := <-resultChan
|
|
||||||
hh.cleanupUnusedChannel(result.channel, result.reqs)
|
|
||||||
}()
|
|
||||||
return nil, errors.New("timeout opening forwarded-tcpip channel")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
|
|
||||||
if channel != nil {
|
|
||||||
if err := channel.Close(); err != nil {
|
|
||||||
log.Printf("Failed to close unused channel: %v", err)
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
|
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
|
||||||
fingerprintMiddleware := middleware.NewTunnelFingerprint()
|
fingerprintMiddleware := middleware.NewTunnelFingerprint()
|
||||||
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
|
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package transport
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -220,8 +221,8 @@ func (m *MockForwarder) Listener() net.Listener {
|
|||||||
return args.Get(0).(net.Listener)
|
return args.Get(0).(net.Listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
args := m.Called(payload)
|
args := m.Called(ctx, payload)
|
||||||
if args.Get(0) == nil {
|
if args.Get(0) == nil {
|
||||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||||
}
|
}
|
||||||
@@ -358,12 +359,10 @@ func TestHandler(t *testing.T) {
|
|||||||
isTLS: true,
|
isTLS: true,
|
||||||
redirectTLS: false,
|
redirectTLS: false,
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||||
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
|
expected: []byte(""),
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
setupMocks: func(msr *MockSessionRegistry) {
|
||||||
mockSession := new(MockSession)
|
mockSession := new(MockSession)
|
||||||
mockForwarder := new(MockForwarder)
|
mockForwarder := new(MockForwarder)
|
||||||
mockLifecycle := new(MockLifecycle)
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
|
|
||||||
msr.On("Get", types.SessionKey{
|
msr.On("Get", types.SessionKey{
|
||||||
Id: "test",
|
Id: "test",
|
||||||
@@ -371,15 +370,9 @@ func TestHandler(t *testing.T) {
|
|||||||
}).Return(mockSession, nil)
|
}).Return(mockSession, nil)
|
||||||
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
mockSession.On("Forwarder").Return(mockForwarder)
|
||||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
|
||||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
||||||
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
w := args.Get(0).(io.Writer)
|
|
||||||
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
|
|
||||||
})
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -391,8 +384,6 @@ func TestHandler(t *testing.T) {
|
|||||||
setupMocks: func(msr *MockSessionRegistry) {
|
setupMocks: func(msr *MockSessionRegistry) {
|
||||||
mockSession := new(MockSession)
|
mockSession := new(MockSession)
|
||||||
mockForwarder := new(MockForwarder)
|
mockForwarder := new(MockForwarder)
|
||||||
mockLifecycle := new(MockLifecycle)
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
mockSSHChannel := new(MockSSHChannel)
|
||||||
|
|
||||||
msr.On("Get", types.SessionKey{
|
msr.On("Get", types.SessionKey{
|
||||||
@@ -401,13 +392,11 @@ func TestHandler(t *testing.T) {
|
|||||||
}).Return(mockSession, nil)
|
}).Return(mockSession, nil)
|
||||||
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
mockSession.On("Forwarder").Return(mockForwarder)
|
||||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
|
||||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
reqCh := make(chan *ssh.Request)
|
||||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
mockSSHChannel.On("Close").Return(nil)
|
||||||
@@ -427,8 +416,6 @@ func TestHandler(t *testing.T) {
|
|||||||
setupMocks: func(msr *MockSessionRegistry) {
|
setupMocks: func(msr *MockSessionRegistry) {
|
||||||
mockSession := new(MockSession)
|
mockSession := new(MockSession)
|
||||||
mockForwarder := new(MockForwarder)
|
mockForwarder := new(MockForwarder)
|
||||||
mockLifecycle := new(MockLifecycle)
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
mockSSHChannel := new(MockSSHChannel)
|
||||||
|
|
||||||
msr.On("Get", types.SessionKey{
|
msr.On("Get", types.SessionKey{
|
||||||
@@ -437,13 +424,11 @@ func TestHandler(t *testing.T) {
|
|||||||
}).Return(mockSession, nil)
|
}).Return(mockSession, nil)
|
||||||
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
mockSession.On("Forwarder").Return(mockForwarder)
|
||||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
|
||||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
reqCh := make(chan *ssh.Request)
|
||||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
mockSSHChannel.On("Close").Return(nil)
|
||||||
@@ -524,18 +509,15 @@ func TestHandler(t *testing.T) {
|
|||||||
setupMocks: func(msr *MockSessionRegistry) {
|
setupMocks: func(msr *MockSessionRegistry) {
|
||||||
mockSession := new(MockSession)
|
mockSession := new(MockSession)
|
||||||
mockForwarder := new(MockForwarder)
|
mockForwarder := new(MockForwarder)
|
||||||
mockLifecycle := new(MockLifecycle)
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
mockSSHChannel := new(MockSSHChannel)
|
||||||
|
|
||||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
mockSession.On("Forwarder").Return(mockForwarder)
|
||||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
|
||||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
reqCh := make(chan *ssh.Request)
|
||||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
mockSSHChannel.On("Close").Return(nil)
|
||||||
@@ -560,19 +542,16 @@ func TestHandler(t *testing.T) {
|
|||||||
setupMocks: func(msr *MockSessionRegistry) {
|
setupMocks: func(msr *MockSessionRegistry) {
|
||||||
mockSession := new(MockSession)
|
mockSession := new(MockSession)
|
||||||
mockForwarder := new(MockForwarder)
|
mockForwarder := new(MockForwarder)
|
||||||
|
mockSSHChannel := new(MockSSHChannel)
|
||||||
|
|
||||||
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
|
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
|
||||||
return k.Id == "test"
|
return k.Id == "test"
|
||||||
})).Return(mockSession, nil)
|
})).Return(mockSession, nil)
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
mockSession.On("Forwarder").Return(mockForwarder)
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
|
|
||||||
mockLifecycle := new(MockLifecycle)
|
|
||||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
|
||||||
reqCh := make(chan *ssh.Request)
|
reqCh := make(chan *ssh.Request)
|
||||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
mockSSHChannel.On("Close").Return(nil)
|
||||||
},
|
},
|
||||||
setupConn: func() (net.Conn, net.Conn) {
|
setupConn: func() (net.Conn, net.Conn) {
|
||||||
@@ -592,18 +571,14 @@ func TestHandler(t *testing.T) {
|
|||||||
setupMocks: func(msr *MockSessionRegistry) {
|
setupMocks: func(msr *MockSessionRegistry) {
|
||||||
mockSession := new(MockSession)
|
mockSession := new(MockSession)
|
||||||
mockForwarder := new(MockForwarder)
|
mockForwarder := new(MockForwarder)
|
||||||
mockLifecycle := new(MockLifecycle)
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
mockSSHChannel := new(MockSSHChannel)
|
||||||
|
|
||||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
mockSession.On("Forwarder").Return(mockForwarder)
|
||||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
|
||||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
reqCh := make(chan *ssh.Request)
|
reqCh := make(chan *ssh.Request)
|
||||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||||
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
||||||
@@ -619,38 +594,20 @@ func TestHandler(t *testing.T) {
|
|||||||
isTLS: true,
|
isTLS: true,
|
||||||
redirectTLS: false,
|
redirectTLS: false,
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||||
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
|
expected: []byte(""),
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
setupMocks: func(msr *MockSessionRegistry) {
|
||||||
oldTimeout := openChannelTimeout
|
|
||||||
openChannelTimeout = 10 * time.Millisecond
|
|
||||||
t.Cleanup(func() {
|
|
||||||
openChannelTimeout = oldTimeout
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
})
|
|
||||||
|
|
||||||
mockSession := new(MockSession)
|
mockSession := new(MockSession)
|
||||||
mockForwarder := new(MockForwarder)
|
mockForwarder := new(MockForwarder)
|
||||||
mockLifecycle := new(MockLifecycle)
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
|
||||||
|
|
||||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
mockSession.On("Forwarder").Return(mockForwarder)
|
||||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
|
||||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||||
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
w := args.Get(0).(io.Writer)
|
|
||||||
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
|
|
||||||
})
|
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
mockForwarder.On("OpenForwardedChannel", mock.Anything, []byte("payload")).Run(func(args mock.Arguments) {
|
||||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Run(func(args mock.Arguments) {
|
ctx := args.Get(0).(context.Context)
|
||||||
time.Sleep(50 * time.Millisecond)
|
<-ctx.Done()
|
||||||
}).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
|
||||||
|
|
||||||
mockSSHChannel.On("Close").Return(fmt.Errorf("cleanup close error"))
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -740,7 +697,7 @@ func TestHandler(t *testing.T) {
|
|||||||
} else {
|
} else {
|
||||||
assert.Equal(t, string(tt.expected), string(response))
|
assert.Equal(t, string(tt.expected), string(response))
|
||||||
}
|
}
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(10 * time.Second):
|
||||||
if clientConn != nil {
|
if clientConn != nil {
|
||||||
t.Fatal("Test timeout - no response received")
|
t.Fatal("Test timeout - no response received")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package transport
|
package transport
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
@@ -17,7 +19,7 @@ type tcp struct {
|
|||||||
|
|
||||||
type Forwarder interface {
|
type Forwarder interface {
|
||||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||||
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
||||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +56,9 @@ func (tt *tcp) handleTcp(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||||
channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
|
defer cancel()
|
||||||
|
channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func TestTCPServer_Serve_Success(t *testing.T) {
|
|||||||
payload := []byte("test-payload")
|
payload := []byte("test-payload")
|
||||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
||||||
reqs := make(chan *ssh.Request)
|
reqs := make(chan *ssh.Request)
|
||||||
mf.On("OpenForwardedChannel", payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
|
mf.On("OpenForwardedChannel", mock.Anything, payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
|
||||||
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
|
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -104,7 +104,7 @@ func TestTCPServer_handleTcp_Success(t *testing.T) {
|
|||||||
|
|
||||||
reqs := make(chan *ssh.Request)
|
reqs := make(chan *ssh.Request)
|
||||||
mockChannel := new(MockSSHChannel)
|
mockChannel := new(MockSSHChannel)
|
||||||
mf.On("OpenForwardedChannel", payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
|
mf.On("OpenForwardedChannel", mock.Anything, payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
|
||||||
|
|
||||||
mf.On("HandleConnection", serverConn, mockChannel).Return()
|
mf.On("HandleConnection", serverConn, mockChannel).Return()
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ func TestTCPServer_handleTcp_CloseError(t *testing.T) {
|
|||||||
|
|
||||||
payload := []byte("test-payload")
|
payload := []byte("test-payload")
|
||||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
||||||
mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||||
|
|
||||||
srv.handleTcp(mc)
|
srv.handleTcp(mc)
|
||||||
mc.AssertExpectations(t)
|
mc.AssertExpectations(t)
|
||||||
@@ -138,7 +138,7 @@ func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
|
|||||||
|
|
||||||
payload := []byte("test-payload")
|
payload := []byte("test-payload")
|
||||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
||||||
mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
mf.On("OpenForwardedChannel", mock.Anything, payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||||
|
|
||||||
srv.handleTcp(serverConn)
|
srv.handleTcp(serverConn)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestTransportInterface(t *testing.T) {
|
|
||||||
var _ Transport = (*httpServer)(nil)
|
|
||||||
var _ Transport = (*https)(nil)
|
|
||||||
var _ Transport = (*tcp)(nil)
|
|
||||||
|
|
||||||
assert.True(t, true)
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -8,7 +9,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/session/slug"
|
"tunnel_pls/session/slug"
|
||||||
"tunnel_pls/types"
|
"tunnel_pls/types"
|
||||||
@@ -25,7 +25,7 @@ type Forwarder interface {
|
|||||||
ForwardedPort() uint16
|
ForwardedPort() uint16
|
||||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||||
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
||||||
WriteBadGatewayResponse(dst io.Writer)
|
WriteBadGatewayResponse(dst io.Writer)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
@@ -60,7 +60,7 @@ func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64,
|
|||||||
return io.CopyBuffer(dst, src, buf)
|
return io.CopyBuffer(dst, src, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
func (f *forwarder) OpenForwardedChannel(ctx context.Context, payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
type channelResult struct {
|
type channelResult struct {
|
||||||
channel ssh.Channel
|
channel ssh.Channel
|
||||||
reqs <-chan *ssh.Request
|
reqs <-chan *ssh.Request
|
||||||
@@ -72,7 +72,7 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
|
|||||||
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
||||||
select {
|
select {
|
||||||
case resultChan <- channelResult{channel, reqs, err}:
|
case resultChan <- channelResult{channel, reqs, err}:
|
||||||
default:
|
case <-ctx.Done():
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
err = channel.Close()
|
err = channel.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -87,8 +87,8 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
|
|||||||
select {
|
select {
|
||||||
case result := <-resultChan:
|
case result := <-resultChan:
|
||||||
return result.channel, result.reqs, result.err
|
return result.channel, result.reqs, result.err
|
||||||
case <-time.After(5 * time.Second):
|
case <-ctx.Done():
|
||||||
return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
|
return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user