package session import ( "bytes" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/binary" "encoding/pem" "fmt" "net" "strconv" "strings" "testing" "time" "tunnel_pls/internal/config" portUtil "tunnel_pls/internal/port" "tunnel_pls/internal/registry" "tunnel_pls/session/lifecycle" "tunnel_pls/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" ) type mockRandom struct { mock.Mock } func (m *mockRandom) String(length int) (string, error) { args := m.Called(length) return args.String(0), args.Error(1) } type mockConfig struct { mock.Mock config.Config } func (m *mockConfig) Domain() string { return m.Called().String(0) } func (m *mockConfig) SSHPort() string { return m.Called().String(0) } func (m *mockConfig) Mode() types.ServerMode { args := m.Called() if args.Get(0) == nil { return 0 } switch v := args.Get(0).(type) { case types.ServerMode: return v case int: return types.ServerMode(v) default: return types.ServerMode(args.Int(0)) } } func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) } type mockRegistry struct { mock.Mock registry.Registry removedKey types.SessionKey } func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool { return m.Called(key, session).Bool(0) } func (m *mockRegistry) Remove(key types.SessionKey) { m.removedKey = key } type mockPort struct { mock.Mock } func (m *mockPort) AddRange(startPort, endPort uint16) error { return m.Called(startPort, endPort).Error(0) } func (m *mockPort) Unassigned() (uint16, bool) { args := m.Called() var port uint16 if args.Get(0) != nil { switch v := args.Get(0).(type) { case int: port = uint16(v) case uint16: port = v case uint32: port = uint16(v) case int32: port = uint16(v) case float64: port = uint16(v) default: port = uint16(args.Int(0)) } } return port, args.Bool(1) } func (m *mockPort) SetStatus(port uint16, assigned bool) error { return m.Called(port, assigned).Error(0) } func (m *mockPort) Claim(port uint16) bool { return m.Called(port).Bool(0) } type mockSSHConn struct { ssh.Conn mock.Mock } func (m *mockSSHConn) Wait() error { return m.Called().Error(0) } func (m *mockSSHConn) Close() error { return m.Called().Error(0) } func (m *mockSSHConn) User() string { return m.Called().String(0) } type mockSSHChannel struct { ssh.Channel mock.Mock } func (m *mockSSHChannel) Close() error { return m.Called().Error(0) } type mockNewChannel struct { ssh.NewChannel mock.Mock } func (m *mockNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) { args := m.Called() return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) } func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, sChans <-chan ssh.NewChannel, cConn ssh.Conn, cleanup func()) { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) key, err := rsa.GenerateKey(rand.Reader, 2048) require.NoError(t, err) privDER := x509.MarshalPKCS1PrivateKey(key) privBlock := pem.Block{ Type: "RSA PRIVATE KEY", Bytes: privDER, } pk, err := ssh.ParsePrivateKey(pem.EncodeToMemory(&privBlock)) require.NoError(t, err) sCfg := &ssh.ServerConfig{ NoClientAuth: true, } sCfg.AddHostKey(pk) cCfg := &ssh.ClientConfig{ User: "test", HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: 5 * time.Second, } var sConnObj *ssh.ServerConn var sChansChan <-chan ssh.NewChannel var sReqsChan <-chan *ssh.Request errChan := make(chan error, 1) go func() { conn, err := l.Accept() if err != nil { errChan <- err return } sConnObj, sChansChan, sReqsChan, err = ssh.NewServerConn(conn, sCfg) errChan <- err }() conn, err := net.Dial("tcp", l.Addr().String()) require.NoError(t, err) cConnObj, cChans, cReqs, err := ssh.NewClientConn(conn, "pipe", cCfg) require.NoError(t, err) go ssh.DiscardRequests(cReqs) go func() { for newChan := range cChans { if newChan.ChannelType() == "session" { continue } newChan.Reject(ssh.Prohibited, "") } }() select { case err := <-errChan: require.NoError(t, err) case <-time.After(5 * time.Second): t.Fatal("SSH handshake timed out") } return sConnObj, sReqsChan, sChansChan, cConnObj, func() { cConnObj.Close() sConnObj.Close() l.Close() } } func TestNew(t *testing.T) { conf := &Config{ Randomizer: &mockRandom{}, Config: &mockConfig{}, Conn: &ssh.ServerConn{}, InitialReq: make(chan *ssh.Request), SshChan: make(chan ssh.NewChannel), SessionRegistry: &mockRegistry{}, PortRegistry: &mockPort{}, User: "testuser", } s := New(conf) assert.NotNil(t, s) assert.NotNil(t, s.Lifecycle()) assert.NotNil(t, s.Interaction()) assert.NotNil(t, s.Forwarder()) assert.NotNil(t, s.Slug()) } func TestDetail(t *testing.T) { conf := &Config{ Randomizer: &mockRandom{}, Config: &mockConfig{}, Conn: &ssh.ServerConn{}, InitialReq: make(chan *ssh.Request), SshChan: make(chan ssh.NewChannel), SessionRegistry: &mockRegistry{}, PortRegistry: &mockPort{}, User: "testuser", } s := New(conf).(*session) s.forwarder.SetType(types.TunnelTypeHTTP) s.slug.Set("test-slug") s.lifecycle.SetStatus(types.SessionStatusRUNNING) detail := s.Detail() assert.Equal(t, "HTTP", detail.ForwardingType) assert.Equal(t, "test-slug", detail.Slug) assert.Equal(t, "testuser", detail.UserID) assert.True(t, detail.Active) s.forwarder.SetType(types.TunnelTypeTCP) detail = s.Detail() assert.Equal(t, "TCP", detail.ForwardingType) s.forwarder.SetType(types.TunnelTypeUNKNOWN) detail = s.Detail() assert.Equal(t, "UNKNOWN", detail.ForwardingType) } func TestIsBlockedPort(t *testing.T) { tests := []struct { port uint16 expected bool }{ {80, false}, {443, false}, {22, true}, {1023, true}, {1024, false}, {1080, true}, {3306, true}, {8080, true}, {0, false}, } for _, tt := range tests { t.Run(fmt.Sprintf("Port %d", tt.port), func(t *testing.T) { assert.Equal(t, tt.expected, isBlockedPort(tt.port)) }) } } func TestReadSSHString(t *testing.T) { tests := []struct { name string input []byte want string wantErr bool }{ { name: "valid string", input: append([]byte{0, 0, 0, 4}, []byte("test")...), want: "test", wantErr: false, }, { name: "empty string", input: []byte{0, 0, 0, 0}, want: "", wantErr: false, }, { name: "short length", input: []byte{0, 0, 0}, want: "", wantErr: true, }, { name: "missing payload", input: []byte{0, 0, 0, 4, 'a', 'b'}, want: "", wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := readSSHString(bytes.NewReader(tt.input)) if tt.wantErr { assert.Error(t, err) } else { assert.NoError(t, err) assert.Equal(t, tt.want, got) } }) } } func TestHandleGlobalRequest(t *testing.T) { _, sReqs, _, cConn, cleanup := setupSSH(t) defer cleanup() conf := &Config{ Randomizer: &mockRandom{}, Config: &mockConfig{}, Conn: &ssh.ServerConn{}, InitialReq: make(chan *ssh.Request), SshChan: make(chan ssh.NewChannel), SessionRegistry: &mockRegistry{}, PortRegistry: &mockPort{}, User: "testuser", } s := New(conf).(*session) done := make(chan struct{}) go func() { _ = s.HandleGlobalRequest(sReqs) close(done) }() tests := []struct { name string reqType string payload []byte wantReply bool expected bool }{ {"shell", "shell", nil, true, true}, {"pty-req", "pty-req", nil, true, true}, {"window-change valid", "window-change", make([]byte, 16), true, true}, {"window-change invalid", "window-change", make([]byte, 4), true, false}, {"unknown", "unknown", nil, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ok, _, err := cConn.SendRequest(tt.reqType, tt.wantReply, tt.payload) assert.NoError(t, err) assert.Equal(t, tt.expected, ok) }) } cConn.Close() select { case <-done: case <-time.After(2 * time.Second): t.Fatal("HandleGlobalRequest timed out after cConn.Close()") } } func TestHandleTCPIPForward_Table(t *testing.T) { setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { sConn, sReqs, _, cConn, cleanup := setupSSH(t) mRegistry := &mockRegistry{} mPort := &mockPort{} mRandom := &mockRandom{} conf := &Config{ Randomizer: mRandom, Config: &mockConfig{}, Conn: sConn, InitialReq: make(chan *ssh.Request), SshChan: make(chan ssh.NewChannel), SessionRegistry: mRegistry, PortRegistry: mPort, User: "testuser", } s := New(conf).(*session) return s, mRegistry, mPort, mRandom, sConn, sReqs, cConn, cleanup } t.Run("HTTP Forward Success", func(t *testing.T) { s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mRandom.On("String", 20).Return("test-slug-1234567890", nil) mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 80) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() var req *ssh.Request select { case req = <-sReqs: case <-time.After(2 * time.Second): t.Fatal("timed out waiting for tcpip-forward request") } err := s.HandleTCPIPForward(req) assert.NoError(t, err) assert.Equal(t, "test-slug-1234567890", s.slug.String()) }) t.Run("TCP Forward Success", func(t *testing.T) { s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mPort.On("Claim", mock.Anything).Return(true) mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 0) mPort.On("Unassigned").Return(uint16(12345), true) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() var req *ssh.Request select { case req = <-sReqs: case <-time.After(2 * time.Second): t.Fatal("timed out waiting for tcpip-forward request") } err := s.HandleTCPIPForward(req) assert.NoError(t, err) assert.Equal(t, uint16(12345), s.forwarder.ForwardedPort()) }) t.Run("Invalid Payload", func(t *testing.T) { s, _, _, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() payload := []byte{0, 0, 0} go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() var req *ssh.Request select { case req = <-sReqs: case <-time.After(2 * time.Second): t.Fatal("timed out waiting for tcpip-forward request") } err := s.HandleTCPIPForward(req) assert.Error(t, err) }) t.Run("Blocked Port", func(t *testing.T) { s, _, _, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 22) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() var req *ssh.Request select { case req = <-sReqs: case <-time.After(2 * time.Second): t.Fatal("timed out waiting for tcpip-forward request") } err := s.HandleTCPIPForward(req) assert.Error(t, err) }) } func TestStart_Table(t *testing.T) { setup := func(t *testing.T) (*session, *Config, ssh.Conn, func()) { sConn, sReqs, sChans, cConn, cleanup := setupSSH(t) mRegistry := &mockRegistry{} mPort := &mockPort{} mRandom := &mockRandom{} mConfig := &mockConfig{} mConfig.On("Mode").Return(types.ServerModeSTANDALONE) mConfig.On("Domain").Return("example.com") mConfig.On("SSHPort").Return("2222") conf := &Config{ Randomizer: mRandom, Config: mConfig, Conn: sConn, InitialReq: sReqs, SshChan: sChans, SessionRegistry: mRegistry, PortRegistry: mPort, User: "testuser", } s := New(conf).(*session) return s, conf, cConn, cleanup } t.Run("Full Success TCP", func(t *testing.T) { s, conf, cConn, cleanup := setup(t) defer cleanup() payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 0) conf.PortRegistry.(*mockPort).On("Claim", mock.Anything).Return(true) conf.PortRegistry.(*mockPort).On("Unassigned").Return(uint16(0), true) conf.PortRegistry.(*mockPort).On("SetStatus", mock.AnythingOfType("uint16"), mock.Anything).Return(nil) conf.SessionRegistry.(*mockRegistry).On("Register", mock.Anything, mock.Anything).Return(true) conf.Config.(*mockConfig).On("TLSEnabled").Return(false) go func() { time.Sleep(200 * time.Millisecond) ch, reqs, err := cConn.OpenChannel("session", nil) if err == nil { go ssh.DiscardRequests(reqs) time.Sleep(200 * time.Millisecond) _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) time.Sleep(200 * time.Millisecond) ch.Write([]byte("q")) time.Sleep(100 * time.Millisecond) ch.Close() } cConn.Close() }() err := s.Start() assert.NoError(t, err) }) t.Run("Headless mode success", func(t *testing.T) { s, conf, cConn, cleanup := setup(t) defer cleanup() payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 80) conf.Randomizer.(*mockRandom).On("String", 20).Return("headless-slug", nil) conf.SessionRegistry.(*mockRegistry).On("Register", mock.Anything, mock.Anything).Return(true) go func() { time.Sleep(600 * time.Millisecond) _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) time.Sleep(100 * time.Millisecond) cConn.Close() }() err := s.Start() assert.NoError(t, err) }) t.Run("Missing Forward Request", func(t *testing.T) { s, _, cConn, cleanup := setup(t) defer cleanup() go func() { time.Sleep(1200 * time.Millisecond) cConn.Close() }() err := s.Start() assert.Error(t, err) assert.Contains(t, err.Error(), "no forwarding Request") }) t.Run("Unauthorized Headless", func(t *testing.T) { s, conf, cConn, cleanup := setup(t) defer cleanup() conf.User = "UNAUTHORIZED" s = New(conf).(*session) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 80) go func() { time.Sleep(600 * time.Millisecond) _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() err := s.Start() assert.Error(t, err) }) } func TestForwardingFailures(t *testing.T) { setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { sConn, sReqs, _, cConn, cleanup := setupSSH(t) mRegistry := &mockRegistry{} mPort := &mockPort{} mRandom := &mockRandom{} conf := &Config{ Randomizer: mRandom, Config: &mockConfig{}, Conn: sConn, InitialReq: make(chan *ssh.Request), SshChan: make(chan ssh.NewChannel), SessionRegistry: mRegistry, PortRegistry: mPort, User: "testuser", } s := New(conf).(*session) return s, mRegistry, mPort, mRandom, sConn, sReqs, cConn, cleanup } t.Run("HTTP Registration Failed", func(t *testing.T) { s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mRandom.On("String", 20).Return("test-slug", nil) mRegistry.On("Register", mock.Anything, mock.Anything).Return(false) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 80) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() var req *ssh.Request select { case req = <-sReqs: case <-time.After(2 * time.Second): t.Fatal("timed out waiting for tcpip-forward request") } err := s.HandleTCPIPForward(req) assert.Error(t, err) }) t.Run("TCP Port Claim Failed", func(t *testing.T) { s, _, mPort, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mPort.On("Claim", mock.Anything).Return(false) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 1234) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() var req *ssh.Request select { case req = <-sReqs: case <-time.After(2 * time.Second): t.Fatal("timed out waiting for tcpip-forward request") } err := s.HandleTCPIPForward(req) assert.Error(t, err) }) t.Run("HTTP Randomizer Error", func(t *testing.T) { s, _, _, mRandom, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mRandom.On("String", 20).Return("", fmt.Errorf("random error")) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 80) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() req := <-sReqs err := s.HandleTCPIPForward(req) assert.Error(t, err) assert.Contains(t, err.Error(), "random error") }) t.Run("Port Registry No Port", func(t *testing.T) { s, _, mPort, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mPort.On("Unassigned").Return(uint16(0), false) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 0) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() req := <-sReqs err := s.HandleTCPIPForward(req) assert.Error(t, err) assert.Contains(t, err.Error(), "no available port") }) t.Run("Port too large", func(t *testing.T) { s, _, _, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 70000) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() req := <-sReqs err := s.HandleTCPIPForward(req) assert.Error(t, err) assert.Contains(t, err.Error(), "port is larger than allowed") }) t.Run("TCP Registration Failed", func(t *testing.T) { s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mPort.On("Claim", mock.Anything).Return(true) mRegistry.On("Register", mock.Anything, mock.Anything).Return(false) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 1234) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() req := <-sReqs err := s.HandleTCPIPForward(req) assert.Error(t, err) assert.Contains(t, err.Error(), "Failed to register TunnelTypeTCP client") }) t.Run("Finalize Forwarding Failure", func(t *testing.T) { s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mRandom.On("String", 20).Return("test-slug", nil) mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], 80) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() req := <-sReqs cConn.Close() time.Sleep(50 * time.Millisecond) err := s.HandleTCPIPForward(req) assert.Error(t, err) }) t.Run("TCP Listen Failure", func(t *testing.T) { s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mPort.On("Claim", mock.Anything).Return(true) mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) l, err := net.Listen("tcp", "0.0.0.0:0") if err != nil { t.Fatal(err) } defer l.Close() _, portStr, _ := net.SplitHostPort(l.Addr().String()) port, _ := strconv.Atoi(portStr) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) copy(payload[4:13], "localhost") binary.BigEndian.PutUint32(payload[13:17], uint32(port)) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) }() req := <-sReqs err = s.HandleTCPIPForward(req) assert.Error(t, err) assert.Contains(t, err.Error(), "is already in use or restricted") }) } func TestSetupInteractiveMode_Error(t *testing.T) { sConn, _, sChans, _, cleanup := setupSSH(t) defer cleanup() conf := &Config{ Randomizer: &mockRandom{}, Config: &mockConfig{}, Conn: sConn, InitialReq: make(chan *ssh.Request), SshChan: sChans, SessionRegistry: &mockRegistry{}, PortRegistry: &mockPort{}, User: "testuser", } s := New(conf).(*session) mockChan := &mockNewChanFail{} err := s.setupInteractiveMode(mockChan) if err == nil { t.Error("expected error, got nil") } } type mockNewChanFail struct { ssh.NewChannel } func (m *mockNewChanFail) Accept() (ssh.Channel, <-chan *ssh.Request, error) { return nil, nil, fmt.Errorf("accept failed") } func TestWaitForTCPIPForward_EdgeCases(t *testing.T) { t.Run("Wrong Request Type", func(t *testing.T) { _, sReqs, _, cConn, cleanup := setupSSH(t) defer cleanup() s := &session{initialReq: sReqs} go func() { _, _, _ = cConn.SendRequest("not-tcpip-forward", true, nil) }() req := s.waitForTCPIPForward() if req != nil { t.Error("expected nil request") } }) t.Run("Channel Closed", func(t *testing.T) { initialReq := make(chan *ssh.Request) s := &session{initialReq: initialReq} close(initialReq) req := s.waitForTCPIPForward() if req != nil { t.Error("expected nil request") } }) } func TestSetupSessionMode_ChannelClosed(t *testing.T) { sshChan := make(chan ssh.NewChannel) s := &session{sshChan: sshChan} close(sshChan) err := s.setupSessionMode() if err != nil { t.Errorf("unexpected error: %v", err) } } func TestStart_SetupSessionModeError(t *testing.T) { sshChan := make(chan ssh.NewChannel, 1) conf := &Config{ Randomizer: &mockRandom{}, Config: &mockConfig{}, Conn: &ssh.ServerConn{}, InitialReq: make(chan *ssh.Request), SshChan: sshChan, SessionRegistry: &mockRegistry{}, PortRegistry: &mockPort{}, User: "testuser", } s := New(conf).(*session) mockChan := &mockNewChanFail{} sshChan <- mockChan err := s.Start() if err == nil { t.Error("expected error, got nil") } } func TestWaitForSessionEnd_Error(t *testing.T) { mConn := &mockSSHConn{} mConn.On("Wait").Return(fmt.Errorf("wait error")) mConn.On("Close").Return(nil) mForwarder := &mockLifecycleForwarder{} mForwarder.On("TunnelType").Return(types.TunnelTypeTCP) mForwarder.On("ForwardedPort").Return(uint16(80)) mForwarder.On("Close").Return(fmt.Errorf("close error")) mSlug := &mockLifecycleSlug{} mSlug.On("String").Return("slug") mPort := &mockPort{} mPort.On("SetStatus", mock.Anything, mock.Anything).Return(nil) mRegistry := &mockRegistry{} mRegistry.On("Remove", mock.Anything).Return() l := lifecycle.New(mConn, mForwarder, mSlug, mPort, mRegistry, "testuser") s := &session{ lifecycle: l, } err := s.waitForSessionEnd() assert.Error(t, err) } type mockLifecycleForwarder struct { mock.Mock lifecycle.Forwarder } func (m *mockLifecycleForwarder) TunnelType() types.TunnelType { return m.Called().Get(0).(types.TunnelType) } func (m *mockLifecycleForwarder) ForwardedPort() uint16 { args := m.Called() if args.Get(0) == nil { return 0 } switch v := args.Get(0).(type) { case uint16: return v case uint32: return uint16(v) case uint64: return uint16(v) case uint8: return uint16(v) case uint: return uint16(v) case int: return uint16(v) case int8: return uint16(v) case int16: return uint16(v) case int32: return uint16(v) case int64: return uint16(v) case float32: return uint16(v) case float64: return uint16(v) default: return uint16(args.Int(0)) } } func (m *mockLifecycleForwarder) Close() error { return m.Called().Error(0) } type mockLifecycleSlug struct { mock.Mock } func (m *mockLifecycleSlug) String() string { return m.Called().String(0) } func (m *mockLifecycleSlug) Set(slug string) { m.Called(slug) } func TestHandleMissingForwardRequest(t *testing.T) { mConn := &mockSSHConn{} mConfig := &mockConfig{} mConfig.On("Domain").Return("example.com") mConfig.On("SSHPort").Return("2222") mConn.On("Close").Return(nil) conf := &Config{ Randomizer: &mockRandom{}, Config: mConfig, Conn: &ssh.ServerConn{Conn: mConn}, InitialReq: make(chan *ssh.Request), SshChan: make(chan ssh.NewChannel), SessionRegistry: &mockRegistry{}, PortRegistry: &mockPort{}, User: "testuser", } s := New(conf).(*session) err := s.handleMissingForwardRequest() if err == nil { t.Error("expected error, got nil") } } func TestParseForwardPayload_Errors(t *testing.T) { s := &session{} t.Run("Short Address", func(t *testing.T) { _, _, err := s.parseForwardPayload(bytes.NewReader([]byte{0, 0, 0, 4})) if err == nil { t.Error("expected error, got nil") } }) t.Run("Short Port", func(t *testing.T) { payload := append([]byte{0, 0, 0, 4}, []byte("addr")...) _, _, err := s.parseForwardPayload(bytes.NewReader(payload)) if err == nil { t.Error("expected error, got nil") } }) t.Run("Blocked Port", func(t *testing.T) { payload := append([]byte{0, 0, 0, 4}, []byte("addr")...) portBuf := make([]byte, 4) binary.BigEndian.PutUint32(portBuf, 22) payload = append(payload, portBuf...) _, _, err := s.parseForwardPayload(bytes.NewReader(payload)) if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "port is block") { t.Errorf("expected error to contain %q, got %q", "port is block", err.Error()) } }) } func TestDenyForwardingRequest_TunnelNotSetupYet(t *testing.T) { sConn, sReqs, _, cConn, cleanup := setupSSH(t) defer cleanup() mRegistry := &mockRegistry{} mPort := &mockPort{} mRandom := &mockRandom{} conf := &Config{ Randomizer: mRandom, Config: &mockConfig{}, Conn: sConn, InitialReq: sReqs, SshChan: make(chan ssh.NewChannel), SessionRegistry: mRegistry, PortRegistry: mPort, User: "testuser", } s := New(conf).(*session) go func() { _, _, _ = cConn.SendRequest("tcpip-forward", true, nil) }() var req *ssh.Request select { case req = <-sReqs: case <-time.After(time.Second): t.Fatal("timeout") } key := &types.SessionKey{Id: "", Type: types.TunnelTypeUNKNOWN} err := s.denyForwardingRequest(req, key, &mockCloser{}, "test error") if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "test error") { t.Errorf("expected error to contain %q, got %q", "test error", err.Error()) } assert.Equal(t, *key, mRegistry.removedKey) } func TestDenyForwardingRequest_Full(t *testing.T) { setup := func(t *testing.T) (*session, *mockRegistry, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { sConn, sReqs, _, cConn, cleanup := setupSSH(t) mRegistry := &mockRegistry{} conf := &Config{ Randomizer: &mockRandom{}, Config: &mockConfig{}, Conn: sConn, InitialReq: sReqs, SshChan: make(chan ssh.NewChannel), SessionRegistry: mRegistry, PortRegistry: &mockPort{}, User: "testuser", } s := New(conf).(*session) return s, mRegistry, sConn, sReqs, cConn, cleanup } getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request { go func() { _, _, _ = client.SendRequest("tcpip-forward", true, nil) }() select { case req, ok := <-serverReqs: if !ok { t.Fatal("channel closed") } return req case <-time.After(2 * time.Second): t.Fatal("timeout getting request") return nil } } t.Run("All Success", func(t *testing.T) { s, mRegistry, _, sReqs, cConn, cleanup := setup(t) defer cleanup() req := getReq(t, cConn, sReqs) key := &types.SessionKey{Id: "test", Type: types.TunnelTypeHTTP} s.slug.Set("test") s.forwarder.SetType(types.TunnelTypeHTTP) mCloser := &mockCloser{} err := s.denyForwardingRequest(req, key, mCloser, "error") if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "error") { t.Errorf("expected error to contain %q, got %q", "error", err.Error()) } assert.Equal(t, *key, mRegistry.removedKey) }) t.Run("Listener Close error", func(t *testing.T) { s, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() req := getReq(t, cConn, sReqs) mCloser := &mockCloser{err: fmt.Errorf("close error")} err := s.denyForwardingRequest(req, nil, mCloser, "error") if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "close listener: close error") { t.Errorf("expected error to contain %q, got %q", "close listener: close error", err.Error()) } }) t.Run("Reply error", func(t *testing.T) { s, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() req := getReq(t, cConn, sReqs) cConn.Close() time.Sleep(100 * time.Millisecond) err := s.denyForwardingRequest(req, nil, nil, "error") if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "reply request") { t.Errorf("expected error to contain %q, got %q", "reply request", err.Error()) } }) t.Run("Lifecycle Close error", func(t *testing.T) { s, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() req := getReq(t, cConn, sReqs) mLife := &mockLifecycle{closeErr: fmt.Errorf("life close error")} s.lifecycle = mLife err := s.denyForwardingRequest(req, nil, nil, "error") if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "close session: life close error") { t.Errorf("expected error to contain %q, got %q", "close session: life close error", err.Error()) } }) } func TestHandleTCPForward_Failures(t *testing.T) { setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { sConn, sReqs, _, cConn, cleanup := setupSSH(t) mRegistry := &mockRegistry{} mPort := &mockPort{} conf := &Config{ Randomizer: &mockRandom{}, Config: &mockConfig{}, Conn: sConn, InitialReq: sReqs, SshChan: make(chan ssh.NewChannel), SessionRegistry: mRegistry, PortRegistry: mPort, User: "testuser", } s := New(conf).(*session) return s, mRegistry, mPort, sConn, sReqs, cConn, cleanup } getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request { go func() { _, _, _ = client.SendRequest("tcpip-forward", true, nil) }() select { case req, ok := <-serverReqs: if !ok { t.Fatal("channel closed") } return req case <-time.After(2 * time.Second): t.Fatal("timeout getting request") return nil } } t.Run("Port Claim fail", func(t *testing.T) { s, _, mPort, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mPort.On("Claim", mock.Anything).Return(false) err := s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", 1234) if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "already in use") { t.Errorf("expected error to contain %q, got %q", "already in use", err.Error()) } }) t.Run("Listen fail", func(t *testing.T) { s, _, mPort, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mPort.On("Claim", mock.Anything).Return(true) l, err := net.Listen("tcp", "0.0.0.0:0") if err != nil { t.Fatal(err) } defer l.Close() port := uint16(l.Addr().(*net.TCPAddr).Port) err = s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", port) if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "already in use") { t.Errorf("expected error to contain %q, got %q", "already in use", err.Error()) } }) t.Run("Registry Register fail", func(t *testing.T) { s, mRegistry, mPort, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mPort.On("Claim", mock.Anything).Return(true) mRegistry.On("Register", mock.Anything, mock.Anything).Return(false) err := s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", 0) if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "Failed to register") { t.Errorf("expected error to contain %q, got %q", "Failed to register", err.Error()) } }) t.Run("Finalize fail (Reply fail)", func(t *testing.T) { s, mRegistry, mPort, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mPort.On("Claim", mock.Anything).Return(true) mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) req := getReq(t, cConn, sReqs) cConn.Close() time.Sleep(100 * time.Millisecond) err := s.HandleTCPForward(req, "localhost", 0) if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "Failed to finalize forwarding") { t.Errorf("expected error to contain %q, got %q", "Failed to finalize forwarding", err.Error()) } }) } func TestHandleHTTPForward_Failures(t *testing.T) { setup := func(t *testing.T) (*session, *mockRegistry, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { sConn, sReqs, _, cConn, cleanup := setupSSH(t) mRegistry := &mockRegistry{} mRandom := &mockRandom{} s := New(&Config{ Randomizer: mRandom, Config: &mockConfig{}, Conn: sConn, InitialReq: sReqs, SshChan: make(chan ssh.NewChannel), SessionRegistry: mRegistry, PortRegistry: &mockPort{}, User: "testuser", }).(*session) return s, mRegistry, mRandom, sConn, sReqs, cConn, cleanup } getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request { go func() { _, _, _ = client.SendRequest("tcpip-forward", true, nil) }() return <-serverReqs } t.Run("Random fail", func(t *testing.T) { s, _, mRandom, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mRandom.On("String", 20).Return("", fmt.Errorf("random error")) err := s.HandleHTTPForward(getReq(t, cConn, sReqs), 80) if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "Failed to create slug") { t.Errorf("expected error to contain %q, got %q", "Failed to create slug", err.Error()) } }) t.Run("Register fail", func(t *testing.T) { s, mRegistry, mRandom, _, sReqs, cConn, cleanup := setup(t) defer cleanup() mRandom.On("String", 20).Return("slug", nil) mRegistry.On("Register", mock.Anything, mock.Anything).Return(false) err := s.HandleHTTPForward(getReq(t, cConn, sReqs), 80) if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "Failed to register") { t.Errorf("expected error to contain %q, got %q", "Failed to register", err.Error()) } }) } func TestHandleGlobalRequest_Failures(t *testing.T) { _, sReqs, _, cConn, cleanup := setupSSH(t) defer cleanup() conf := &Config{ Randomizer: &mockRandom{}, Config: &mockConfig{}, Conn: &ssh.ServerConn{}, InitialReq: make(chan *ssh.Request), SshChan: make(chan ssh.NewChannel), SessionRegistry: &mockRegistry{}, PortRegistry: &mockPort{}, User: "testuser", } s := New(conf).(*session) done := make(chan struct{}) go func() { _ = s.HandleGlobalRequest(sReqs) close(done) }() tests := []struct { name string reqType string payload []byte wantReply bool expected bool }{ {"shell", "shell", nil, true, true}, {"pty-req", "pty-req", nil, true, true}, {"window-change valid", "window-change", make([]byte, 16), true, true}, {"window-change invalid", "window-change", make([]byte, 4), true, false}, {"unknown", "unknown", nil, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ok, _, err := cConn.SendRequest(tt.reqType, tt.wantReply, tt.payload) assert.NoError(t, err) assert.Equal(t, tt.expected, ok) }) } cConn.Close() select { case <-done: case <-time.After(2 * time.Second): t.Fatal("HandleGlobalRequest timed out after cConn.Close()") } } func TestSetupInteractiveMode_GlobalRequestError(t *testing.T) { sConn, _, sChans, _, cleanup := setupSSH(t) defer cleanup() conf := &Config{ Randomizer: &mockRandom{}, Config: &mockConfig{}, Conn: sConn, InitialReq: make(chan *ssh.Request), SshChan: sChans, SessionRegistry: &mockRegistry{}, PortRegistry: &mockPort{}, User: "testuser", } s := New(conf).(*session) mockChan := &mockNewChanFail{} err := s.setupInteractiveMode(mockChan) if err == nil { t.Error("expected error, got nil") } } type mockCloser struct { err error } func (m *mockCloser) Close() error { return m.err } type mockLifecycle struct { lifecycle.Lifecycle closeErr error conn ssh.Conn user string } func (m *mockLifecycle) Close() error { return m.closeErr } func (m *mockLifecycle) Connection() ssh.Conn { return m.conn } func (m *mockLifecycle) User() string { return m.user } func (m *mockLifecycle) IsActive() bool { return false } func (m *mockLifecycle) PortRegistry() portUtil.Port { return nil } func (m *mockLifecycle) SetChannel(ch ssh.Channel) {} func (m *mockLifecycle) SetStatus(status types.SessionStatus) {} func (m *mockLifecycle) StartedAt() time.Time { return time.Time{} }