From e534972abc32ba07dd71d04ed4c3a1f476b86a44 Mon Sep 17 00:00:00 2001 From: bagas Date: Thu, 22 Jan 2026 13:27:25 +0700 Subject: [PATCH] test(random): add unit tests for random behavior - Added unit tests to cover random string generation and error handling. - Introduced Random interface and random struct for better abstraction. - Updated server, session, and interaction packages to require Random interface for dependency injection. --- internal/random/random.go | 29 ++++++++++++-- internal/random/random_test.go | 61 ++++++++++++++++++++++++++++++ main.go | 7 ++-- server/server.go | 7 +++- session/interaction/interaction.go | 6 ++- session/interaction/model.go | 2 + session/interaction/slug.go | 3 +- session/session.go | 8 ++-- 8 files changed, 109 insertions(+), 14 deletions(-) create mode 100644 internal/random/random_test.go diff --git a/internal/random/random.go b/internal/random/random.go index 929cc7b..cb793d4 100644 --- a/internal/random/random.go +++ b/internal/random/random.go @@ -1,12 +1,35 @@ package random -import "crypto/rand" +import ( + "crypto/rand" + "fmt" + "io" +) -func GenerateRandomString(length int) (string, error) { +var ( + ErrInvalidLength = fmt.Errorf("invalid length") +) + +type Random interface { + String(length int) (string, error) +} + +type random struct { + reader io.Reader +} + +func New() Random { + return &random{reader: rand.Reader} +} + +func (ran *random) String(length int) (string, error) { + if length < 0 { + return "", ErrInvalidLength + } const charset = "abcdefghijklmnopqrstuvwxyz0123456789" b := make([]byte, length) - if _, err := rand.Read(b); err != nil { + if _, err := ran.reader.Read(b); err != nil { return "", err } diff --git a/internal/random/random_test.go b/internal/random/random_test.go new file mode 100644 index 0000000..8c6787e --- /dev/null +++ b/internal/random/random_test.go @@ -0,0 +1,61 @@ +package random + +import ( + "errors" + "fmt" + "testing" +) + +type brainrotReader struct { + err error +} + +func (f *brainrotReader) Read(p []byte) (int, error) { + return 0, f.err +} + +func TestRandom_String(t *testing.T) { + tests := []struct { + name string + length int + wantErr bool + }{ + {"ValidLengthZero", 0, false}, + {"ValidPositiveLength", 10, false}, + {"NegativeLength", -1, true}, + {"VeryLargeLength", 1_000_000, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randomizer := New() + + result, err := randomizer.String(tt.length) + if (err != nil) != tt.wantErr { + t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && len(result) != tt.length { + t.Errorf("String() length = %v, want %v", len(result), tt.length) + } + }) + } +} + +func TestRandomWithFailingReader_String(t *testing.T) { + var randomizer Random + var errBrainrot = fmt.Errorf("you are not sigma enough") + randomizer = &random{reader: &brainrotReader{err: errBrainrot}} + t.Run("test failing reader", func(t *testing.T) { + result, err := randomizer.String(20) + if !errors.Is(err, errBrainrot) { + t.Errorf("String() error = %v, wantErr %v", err, errBrainrot) + return + } + + if result != "" { + t.Errorf("String() result = %v, want an empty string due to error", result) + } + }) +} diff --git a/main.go b/main.go index f897b46..d62f722 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,7 @@ import ( "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/key" "tunnel_pls/internal/port" + "tunnel_pls/internal/random" "tunnel_pls/internal/registry" "tunnel_pls/internal/transport" "tunnel_pls/internal/version" @@ -127,17 +128,17 @@ func main() { } }() } - + portManager := port.New() err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()) if err != nil { log.Fatalf("Failed to initialize port manager: %s", err) return } - + randomizer := random.New() var app server.Server go func() { - app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort()) + app, err = server.New(randomizer, conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort()) if err != nil { errChan <- fmt.Errorf("failed to start server: %s", err) return diff --git a/server/server.go b/server/server.go index f47c579..a1990b4 100644 --- a/server/server.go +++ b/server/server.go @@ -10,6 +10,7 @@ import ( "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/port" + "tunnel_pls/internal/random" "tunnel_pls/internal/registry" "tunnel_pls/session" @@ -21,6 +22,7 @@ type Server interface { Close() error } type server struct { + randomizer random.Random config config.Config sshPort string sshListener net.Listener @@ -30,13 +32,14 @@ type server struct { portRegistry port.Port } -func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { +func New(randomizer random.Random, config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort)) if err != nil { return nil, err } return &server{ + randomizer: randomizer, config: config, sshPort: sshPort, sshListener: listener, @@ -95,7 +98,7 @@ func (s *server) handleConnection(conn net.Conn) { cancel() } log.Println("SSH connection established:", sshConn.User()) - sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) + sshSession := session.New(s.randomizer, s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) err = sshSession.Start() if err != nil { log.Printf("SSH session ended with error: %v", err) diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 5f68102..fe5b496 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -4,6 +4,7 @@ import ( "context" "log" "tunnel_pls/internal/config" + "tunnel_pls/internal/random" "tunnel_pls/session/slug" "tunnel_pls/types" @@ -39,6 +40,7 @@ type Forwarder interface { type CloseFunc func() error type interaction struct { + randomizer random.Random config config.Config channel ssh.Channel slug slug.Slug @@ -76,9 +78,10 @@ func (i *interaction) SetWH(w, h int) { } } -func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { +func New(randomizer random.Random, config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { ctx, cancel := context.WithCancel(context.Background()) return &interaction{ + randomizer: randomizer, config: config, channel: nil, slug: slug, @@ -210,6 +213,7 @@ func (i *interaction) Start() { ti.Width = 50 m := &model{ + randomizer: i.randomizer, domain: i.config.Domain(), protocol: protocol, tunnelType: tunnelType, diff --git a/session/interaction/model.go b/session/interaction/model.go index 189b0a1..3002d16 100644 --- a/session/interaction/model.go +++ b/session/interaction/model.go @@ -3,6 +3,7 @@ package interaction import ( "fmt" "time" + "tunnel_pls/internal/random" "tunnel_pls/types" "github.com/charmbracelet/bubbles/help" @@ -22,6 +23,7 @@ func (i commandItem) Title() string { return i.name } func (i commandItem) Description() string { return i.desc } type model struct { + randomizer random.Random domain string protocol string tunnelType types.TunnelType diff --git a/session/interaction/slug.go b/session/interaction/slug.go index 2b871d4..647cd31 100644 --- a/session/interaction/slug.go +++ b/session/interaction/slug.go @@ -3,7 +3,6 @@ package interaction import ( "fmt" "strings" - "tunnel_pls/internal/random" "tunnel_pls/types" "github.com/charmbracelet/bubbles/key" @@ -47,7 +46,7 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, tea.Batch(tea.ClearScreen, textinput.Blink) default: if key.Matches(msg, m.keymap.random) { - newSubdomain, err := random.GenerateRandomString(20) + newSubdomain, err := m.randomizer.String(20) if err != nil { return m, cmd } diff --git a/session/session.go b/session/session.go index b1895ab..cb7c04c 100644 --- a/session/session.go +++ b/session/session.go @@ -37,6 +37,7 @@ type Session interface { } type session struct { + randomizer random.Random config config.Config initialReq <-chan *ssh.Request sshChan <-chan ssh.NewChannel @@ -49,13 +50,14 @@ type session struct { var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { +func New(randomizer random.Random, config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { slugManager := slug.New() forwarderManager := forwarder.New(config, slugManager, conn) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) - interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) + interactionManager := interaction.New(randomizer, config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) return &session{ + randomizer: randomizer, config: config, initialReq: initialReq, sshChan: sshChan, @@ -346,7 +348,7 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) error { } func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { - randomString, err := random.GenerateRandomString(20) + randomString, err := s.randomizer.String(20) if err != nil { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err)) }