From dbdf8094fa6a54b56fb6c1ae35a33cb8e7fa0e7c Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 14 Jan 2026 15:28:17 +0700 Subject: [PATCH] refactor: replace Get/Set patterns with idiomatic Go interfaces - rename constructors to New - remove Get/Set-style accessors - replace string-based enums with iota-backed types --- internal/grpc/client/client.go | 4 +- server/http.go | 12 +-- session/forwarder/forwarder.go | 60 ++++++------ session/handler.go | 12 +-- session/interaction/interaction.go | 62 ++++++------- session/lifecycle/lifecycle.go | 84 +++++++++-------- session/registry.go | 32 +++---- session/session.go | 144 ++++++++++++++++------------- session/slug/slug.go | 14 +-- types/types.go | 21 ++--- 10 files changed, 231 insertions(+), 214 deletions(-) diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 3d3d1c2..8c701e1 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -263,7 +263,7 @@ func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, }, "slug change failure response") } - userSession.GetInteraction().Redraw() + userSession.Interaction().Redraw() return c.sendNode(subscribe, &proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, Payload: &proto.Node_SlugEventResponse{ @@ -321,7 +321,7 @@ func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto }, "terminate session fetch failed") } - if err = userSession.GetLifecycle().Close(); err != nil { + if err = userSession.Lifecycle().Close(); err != nil { return c.sendNode(subscribe, &proto.Node{ Type: proto.EventType_TERMINATE_SESSION, Payload: &proto.Node_TerminateSessionEventResponse{ diff --git a/server/http.go b/server/http.go index 2420686..e2143d5 100644 --- a/server/http.go +++ b/server/http.go @@ -335,8 +335,8 @@ func (hs *httpServer) handler(conn net.Conn) { return } -func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) { - payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr()) +func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) { + payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr()) type channelResult struct { channel ssh.Channel @@ -346,7 +346,7 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi resultChan := make(chan channelResult, 1) go func() { - channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload) + channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) resultChan <- channelResult{channel, reqs, err} }() @@ -357,14 +357,14 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi case result := <-resultChan: if result.err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) - sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter()) + sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) return } channel = result.channel reqs = result.reqs case <-time.After(5 * time.Second): log.Printf("Timeout opening forwarded-tcpip channel") - sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter()) + sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) return } @@ -390,6 +390,6 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi return } - sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr()) + sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr()) return } diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 3d32a43..c8478a4 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -30,50 +30,50 @@ func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { return io.CopyBuffer(dst, src, buf) } -type Forwarder struct { +type forwarder struct { listener net.Listener tunnelType types.TunnelType forwardedPort uint16 - slugManager slug.Manager + slug slug.Slug lifecycle Lifecycle } -func NewForwarder(slugManager slug.Manager) *Forwarder { - return &Forwarder{ +func New(slug slug.Slug) Forwarder { + return &forwarder{ listener: nil, - tunnelType: "", + tunnelType: types.UNKNOWN, forwardedPort: 0, - slugManager: slugManager, + slug: slug, lifecycle: nil, } } type Lifecycle interface { - GetConnection() ssh.Conn + Connection() ssh.Conn } -type ForwardingController interface { - AcceptTCPConnections() +type Forwarder interface { SetType(tunnelType types.TunnelType) - GetTunnelType() types.TunnelType - GetForwardedPort() uint16 + SetLifecycle(lifecycle Lifecycle) SetForwardedPort(port uint16) SetListener(listener net.Listener) - GetListener() net.Listener - Close() error + Listener() net.Listener + TunnelType() types.TunnelType + ForwardedPort() uint16 HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) - SetLifecycle(lifecycle Lifecycle) CreateForwardedTCPIPPayload(origin net.Addr) []byte WriteBadGatewayResponse(dst io.Writer) + AcceptTCPConnections() + Close() error } -func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { +func (f *forwarder) SetLifecycle(lifecycle Lifecycle) { f.lifecycle = lifecycle } -func (f *Forwarder) AcceptTCPConnections() { +func (f *forwarder) AcceptTCPConnections() { for { - conn, err := f.GetListener().Accept() + conn, err := f.Listener().Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -100,7 +100,7 @@ func (f *Forwarder) AcceptTCPConnections() { resultChan := make(chan channelResult, 1) go func() { - channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) + channel, reqs, err := f.lifecycle.Connection().OpenChannel("forwarded-tcpip", payload) resultChan <- channelResult{channel, reqs, err} }() @@ -130,7 +130,7 @@ func (f *Forwarder) AcceptTCPConnections() { } } -func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { +func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { defer func() { _, err := io.Copy(io.Discard, src) if err != nil { @@ -174,31 +174,31 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA wg.Wait() } -func (f *Forwarder) SetType(tunnelType types.TunnelType) { +func (f *forwarder) SetType(tunnelType types.TunnelType) { f.tunnelType = tunnelType } -func (f *Forwarder) GetTunnelType() types.TunnelType { +func (f *forwarder) TunnelType() types.TunnelType { return f.tunnelType } -func (f *Forwarder) GetForwardedPort() uint16 { +func (f *forwarder) ForwardedPort() uint16 { return f.forwardedPort } -func (f *Forwarder) SetForwardedPort(port uint16) { +func (f *forwarder) SetForwardedPort(port uint16) { f.forwardedPort = port } -func (f *Forwarder) SetListener(listener net.Listener) { +func (f *forwarder) SetListener(listener net.Listener) { f.listener = listener } -func (f *Forwarder) GetListener() net.Listener { +func (f *forwarder) Listener() net.Listener { return f.listener } -func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { +func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) { _, err := dst.Write(types.BadGatewayResponse) if err != nil { log.Printf("failed to write Bad Gateway response: %v", err) @@ -206,20 +206,20 @@ func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { } } -func (f *Forwarder) Close() error { - if f.GetListener() != nil { +func (f *forwarder) Close() error { + if f.Listener() != nil { return f.listener.Close() } return nil } -func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { +func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { var buf bytes.Buffer host, originPort := parseAddr(origin.String()) writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort())) + err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort())) if err != nil { log.Printf("Failed to write string to buffer: %v", err) return nil diff --git a/session/handler.go b/session/handler.go index c5aad63..f80f222 100644 --- a/session/handler.go +++ b/session/handler.go @@ -15,7 +15,7 @@ import ( var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { +func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { for req := range GlobalRequest { switch req.Type { case "shell", "pty-req": @@ -56,7 +56,7 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { } } -func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { +func (s *session) HandleTCPIPForward(req *ssh.Request) { log.Println("Port forwarding request detected") fail := func(msg string) { @@ -103,7 +103,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { } } -func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { +func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) { fail := func(msg string, key *types.SessionKey) { log.Println(msg) if key != nil { @@ -137,11 +137,11 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { s.forwarder.SetType(types.HTTP) s.forwarder.SetForwardedPort(portToBind) - s.slugManager.Set(slug) + s.slug.Set(slug) s.lifecycle.SetStatus(types.RUNNING) } -func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { +func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { fail := func(msg string) { log.Println(msg) if err := req.Reply(false, nil); err != nil { @@ -219,7 +219,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind s.forwarder.SetType(types.TCP) s.forwarder.SetListener(listener) s.forwarder.SetForwardedPort(portToBind) - s.slugManager.Set(key.Id) + s.slug.Set(key.Id) s.lifecycle.SetStatus(types.RUNNING) go s.forwarder.AcceptTCPConnections() } diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 6d37f69..86bbec7 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -23,34 +23,34 @@ import ( type Lifecycle interface { Close() error - GetUser() string + User() string } type SessionRegistry interface { Update(user string, oldKey, newKey types.SessionKey) error } -type Controller interface { +type Interaction interface { + Mode() types.Mode SetChannel(channel ssh.Channel) SetLifecycle(lifecycle Lifecycle) - Start() - SetWH(w, h int) - Redraw() SetSessionRegistry(registry SessionRegistry) SetMode(m types.Mode) - GetMode() types.Mode + SetWH(w, h int) + Start() + Redraw() Send(message string) error } type Forwarder interface { Close() error - GetTunnelType() types.TunnelType - GetForwardedPort() uint16 + TunnelType() types.TunnelType + ForwardedPort() uint16 } -type Interaction struct { +type interaction struct { channel ssh.Channel - slugManager slug.Manager + slug slug.Slug forwarder Forwarder lifecycle Lifecycle sessionRegistry SessionRegistry @@ -60,22 +60,22 @@ type Interaction struct { mode types.Mode } -func (i *Interaction) SetMode(m types.Mode) { +func (i *interaction) SetMode(m types.Mode) { i.mode = m } -func (i *Interaction) GetMode() types.Mode { +func (i *interaction) Mode() types.Mode { return i.mode } -func (i *Interaction) Send(message string) error { +func (i *interaction) Send(message string) error { if i.channel != nil { _, err := i.channel.Write([]byte(message)) return err } return nil } -func (i *Interaction) SetWH(w, h int) { +func (i *interaction) SetWH(w, h int) { if i.program != nil { i.program.Send(tea.WindowSizeMsg{ Width: w, @@ -103,14 +103,14 @@ type model struct { commandList list.Model slugInput textinput.Model slugError string - interaction *Interaction + interaction *interaction width int height int } func (m *model) getTunnelURL() string { if m.tunnelType == types.HTTP { - return buildURL(m.protocol, m.interaction.slugManager.Get(), m.domain) + return buildURL(m.protocol, m.interaction.slug.String(), m.domain) } return fmt.Sprintf("tcp://%s:%d", m.domain, m.port) } @@ -123,11 +123,11 @@ type keymap struct { type tickMsg time.Time -func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction { +func New(slug slug.Slug, forwarder Forwarder) Interaction { ctx, cancel := context.WithCancel(context.Background()) - return &Interaction{ + return &interaction{ channel: nil, - slugManager: slugManager, + slug: slug, forwarder: forwarder, lifecycle: nil, sessionRegistry: nil, @@ -137,19 +137,19 @@ func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction } } -func (i *Interaction) SetSessionRegistry(registry SessionRegistry) { +func (i *interaction) SetSessionRegistry(registry SessionRegistry) { i.sessionRegistry = registry } -func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { +func (i *interaction) SetLifecycle(lifecycle Lifecycle) { i.lifecycle = lifecycle } -func (i *Interaction) SetChannel(channel ssh.Channel) { +func (i *interaction) SetChannel(channel ssh.Channel) { i.channel = channel } -func (i *Interaction) Stop() { +func (i *interaction) Stop() { if i.cancel != nil { i.cancel() } @@ -242,8 +242,8 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(tea.ClearScreen, textinput.Blink) case "enter": inputValue := m.slugInput.Value() - if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.GetUser(), types.SessionKey{ - Id: m.interaction.slugManager.Get(), + if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.User(), types.SessionKey{ + Id: m.interaction.slug.String(), Type: types.HTTP, }, types.SessionKey{ Id: inputValue, @@ -285,7 +285,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if item.name == "slug" { m.showingCommands = false m.editingSlug = true - m.slugInput.SetValue(m.interaction.slugManager.Get()) + m.slugInput.SetValue(m.interaction.slug.String()) m.slugInput.Focus() return m, tea.Batch(tea.ClearScreen, textinput.Blink) } else if item.name == "tunnel-type" { @@ -317,7 +317,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, nil } -func (i *Interaction) Redraw() { +func (i *interaction) Redraw() { if i.program != nil { i.program.Send(tea.ClearScreen()) } @@ -691,7 +691,7 @@ func (m *model) View() string { MarginBottom(boxMargin). Width(boxMaxWidth) - authenticatedUser := m.interaction.lifecycle.GetUser() + authenticatedUser := m.interaction.lifecycle.User() userInfoStyle := lipgloss.NewStyle(). Foreground(lipgloss.Color("#FAFAFA")). @@ -767,7 +767,7 @@ func (m *model) View() string { return b.String() } -func (i *Interaction) Start() { +func (i *interaction) Start() { if i.mode == types.HEADLESS { return } @@ -779,8 +779,8 @@ func (i *Interaction) Start() { protocol = "https" } - tunnelType := i.forwarder.GetTunnelType() - port := i.forwarder.GetForwardedPort() + tunnelType := i.forwarder.TunnelType() + port := i.forwarder.ForwardedPort() items := []list.Item{ commandItem{name: "slug", desc: "Set custom subdomain"}, diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index ccc01f0..704b4a8 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -15,115 +15,119 @@ import ( type Forwarder interface { Close() error - GetTunnelType() types.TunnelType - GetForwardedPort() uint16 + TunnelType() types.TunnelType + ForwardedPort() uint16 } type SessionRegistry interface { Remove(key types.SessionKey) } -type Lifecycle struct { +type lifecycle struct { status types.Status conn ssh.Conn channel ssh.Channel forwarder Forwarder sessionRegistry SessionRegistry - slugManager slug.Manager + slug slug.Slug startedAt time.Time user string } -func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager, user string) *Lifecycle { - return &Lifecycle{ +func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle { + return &lifecycle{ status: types.INITIALIZING, conn: conn, channel: nil, forwarder: forwarder, - slugManager: slugManager, + slug: slugManager, sessionRegistry: nil, startedAt: time.Now(), user: user, } } -func (l *Lifecycle) SetSessionRegistry(registry SessionRegistry) { +func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) { l.sessionRegistry = registry } -type SessionLifecycle interface { - Close() error - SetStatus(status types.Status) - GetConnection() ssh.Conn - GetChannel() ssh.Channel - GetUser() string +type Lifecycle interface { + Connection() ssh.Conn + Channel() ssh.Channel + User() string SetChannel(channel ssh.Channel) SetSessionRegistry(registry SessionRegistry) + SetStatus(status types.Status) IsActive() bool StartedAt() time.Time + Close() error } -func (l *Lifecycle) GetUser() string { +func (l *lifecycle) User() string { return l.user } -func (l *Lifecycle) GetChannel() ssh.Channel { +func (l *lifecycle) Channel() ssh.Channel { return l.channel } -func (l *Lifecycle) SetChannel(channel ssh.Channel) { +func (l *lifecycle) SetChannel(channel ssh.Channel) { l.channel = channel } -func (l *Lifecycle) GetConnection() ssh.Conn { +func (l *lifecycle) Connection() ssh.Conn { return l.conn } -func (l *Lifecycle) SetStatus(status types.Status) { +func (l *lifecycle) SetStatus(status types.Status) { l.status = status if status == types.RUNNING && l.startedAt.IsZero() { l.startedAt = time.Now() } } -func (l *Lifecycle) Close() error { - err := l.forwarder.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - return err +func (l *lifecycle) Close() error { + var firstErr error + tunnelType := l.forwarder.TunnelType() + + if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + firstErr = err } if l.channel != nil { - err := l.channel.Close() - if err != nil && !errors.Is(err, io.EOF) { - return err + if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + if firstErr == nil { + firstErr = err + } } } if l.conn != nil { - err := l.conn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - return err + if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + if firstErr == nil { + firstErr = err + } } } - clientSlug := l.slugManager.Get() - if clientSlug != "" && l.sessionRegistry.Remove != nil { - key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()} - l.sessionRegistry.Remove(key) + clientSlug := l.slug.String() + key := types.SessionKey{ + Id: clientSlug, + Type: tunnelType, } + l.sessionRegistry.Remove(key) - if l.forwarder.GetTunnelType() == types.TCP { - err = portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false) - if err != nil { - return err + if tunnelType == types.TCP { + if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { + firstErr = err } } - return nil + return firstErr } -func (l *Lifecycle) IsActive() bool { +func (l *lifecycle) IsActive() bool { return l.status == types.RUNNING } -func (l *Lifecycle) StartedAt() time.Time { +func (l *lifecycle) StartedAt() time.Time { return l.startedAt } diff --git a/session/registry.go b/session/registry.go index 3113dd6..6698cf1 100644 --- a/session/registry.go +++ b/session/registry.go @@ -9,27 +9,27 @@ import ( type Key = types.SessionKey type Registry interface { - Get(key Key) (session *SSHSession, err error) - GetWithUser(user string, key Key) (session *SSHSession, err error) + Get(key Key) (session Session, err error) + GetWithUser(user string, key Key) (session Session, err error) Update(user string, oldKey, newKey Key) error - Register(key Key, session *SSHSession) (success bool) + Register(key Key, session Session) (success bool) Remove(key Key) - GetAllSessionFromUser(user string) []*SSHSession + GetAllSessionFromUser(user string) []Session } type registry struct { mu sync.RWMutex - byUser map[string]map[Key]*SSHSession + byUser map[string]map[Key]Session slugIndex map[Key]string } func NewRegistry() Registry { return ®istry{ - byUser: make(map[string]map[Key]*SSHSession), + byUser: make(map[string]map[Key]Session), slugIndex: make(map[Key]string), } } -func (r *registry) Get(key Key) (session *SSHSession, err error) { +func (r *registry) Get(key Key) (session Session, err error) { r.mu.RLock() defer r.mu.RUnlock() @@ -45,7 +45,7 @@ func (r *registry) Get(key Key) (session *SSHSession, err error) { return client, nil } -func (r *registry) GetWithUser(user string, key Key) (session *SSHSession, err error) { +func (r *registry) GetWithUser(user string, key Key) (session Session, err error) { r.mu.RLock() defer r.mu.RUnlock() @@ -87,17 +87,17 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { delete(r.byUser[user], oldKey) delete(r.slugIndex, oldKey) - client.slugManager.Set(newKey.Id) + client.Slug().Set(newKey.Id) r.slugIndex[newKey] = user if r.byUser[user] == nil { - r.byUser[user] = make(map[Key]*SSHSession) + r.byUser[user] = make(map[Key]Session) } r.byUser[user][newKey] = client return nil } -func (r *registry) Register(key Key, session *SSHSession) (success bool) { +func (r *registry) Register(key Key, session Session) (success bool) { r.mu.Lock() defer r.mu.Unlock() @@ -105,9 +105,9 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) { return false } - userID := session.lifecycle.GetUser() + userID := session.Lifecycle().User() if r.byUser[userID] == nil { - r.byUser[userID] = make(map[Key]*SSHSession) + r.byUser[userID] = make(map[Key]Session) } r.byUser[userID][key] = session @@ -115,16 +115,16 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) { return true } -func (r *registry) GetAllSessionFromUser(user string) []*SSHSession { +func (r *registry) GetAllSessionFromUser(user string) []Session { r.mu.RLock() defer r.mu.RUnlock() m := r.byUser[user] if len(m) == 0 { - return []*SSHSession{} + return []Session{} } - sessions := make([]*SSHSession, 0, len(m)) + sessions := make([]Session, 0, len(m)) for _, s := range m { sessions = append(sessions, s) } diff --git a/session/session.go b/session/session.go index 98a6cd4..e01355c 100644 --- a/session/session.go +++ b/session/session.go @@ -14,61 +14,6 @@ import ( "golang.org/x/crypto/ssh" ) -type Session interface { - HandleGlobalRequest(ch <-chan *ssh.Request) - HandleTCPIPForward(req *ssh.Request) - HandleHTTPForward(req *ssh.Request, port uint16) - HandleTCPForward(req *ssh.Request, addr string, port uint16) -} - -type SSHSession struct { - initialReq <-chan *ssh.Request - sshReqChannel <-chan ssh.NewChannel - lifecycle lifecycle.SessionLifecycle - interaction interaction.Controller - forwarder forwarder.ForwardingController - slugManager slug.Manager - registry Registry -} - -func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle { - return s.lifecycle -} - -func (s *SSHSession) GetInteraction() interaction.Controller { - return s.interaction -} - -func (s *SSHSession) GetForwarder() forwarder.ForwardingController { - return s.forwarder -} - -func (s *SSHSession) GetSlugManager() slug.Manager { - return s.slugManager -} - -func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) *SSHSession { - slugManager := slug.NewManager() - forwarderManager := forwarder.NewForwarder(slugManager) - interactionManager := interaction.NewInteraction(slugManager, forwarderManager) - lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager, user) - - interactionManager.SetLifecycle(lifecycleManager) - forwarderManager.SetLifecycle(lifecycleManager) - interactionManager.SetSessionRegistry(sessionRegistry) - lifecycleManager.SetSessionRegistry(sessionRegistry) - - return &SSHSession{ - initialReq: forwardingReq, - sshReqChannel: sshChan, - lifecycle: lifecycleManager, - interaction: interactionManager, - forwarder: forwarderManager, - slugManager: slugManager, - registry: sessionRegistry, - } -} - type Detail struct { ForwardingType string `json:"forwarding_type,omitempty"` Slug string `json:"slug,omitempty"` @@ -77,21 +22,90 @@ type Detail struct { StartedAt time.Time `json:"started_at,omitempty"` } -func (s *SSHSession) Detail() Detail { - return Detail{ - ForwardingType: string(s.forwarder.GetTunnelType()), - Slug: s.slugManager.Get(), - UserID: s.lifecycle.GetUser(), +type Session interface { + HandleGlobalRequest(ch <-chan *ssh.Request) + HandleTCPIPForward(req *ssh.Request) + HandleHTTPForward(req *ssh.Request, port uint16) + HandleTCPForward(req *ssh.Request, addr string, port uint16) + Lifecycle() lifecycle.Lifecycle + Interaction() interaction.Interaction + Forwarder() forwarder.Forwarder + Slug() slug.Slug + Detail() *Detail + Start() error +} + +type session struct { + initialReq <-chan *ssh.Request + sshChan <-chan ssh.NewChannel + lifecycle lifecycle.Lifecycle + interaction interaction.Interaction + forwarder forwarder.Forwarder + slug slug.Slug + registry Registry +} + +func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session { + slugManager := slug.New() + forwarderManager := forwarder.New(slugManager) + interactionManager := interaction.New(slugManager, forwarderManager) + lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user) + + interactionManager.SetLifecycle(lifecycleManager) + forwarderManager.SetLifecycle(lifecycleManager) + interactionManager.SetSessionRegistry(sessionRegistry) + lifecycleManager.SetSessionRegistry(sessionRegistry) + + return &session{ + initialReq: initialReq, + sshChan: sshChan, + lifecycle: lifecycleManager, + interaction: interactionManager, + forwarder: forwarderManager, + slug: slugManager, + registry: sessionRegistry, + } +} + +func (s *session) Lifecycle() lifecycle.Lifecycle { + return s.lifecycle +} + +func (s *session) Interaction() interaction.Interaction { + return s.interaction +} + +func (s *session) Forwarder() forwarder.Forwarder { + return s.forwarder +} + +func (s *session) Slug() slug.Slug { + return s.slug +} + +func (s *session) Detail() *Detail { + var tunnelType string + if s.forwarder.TunnelType() == types.HTTP { + tunnelType = "HTTP" + } else if s.forwarder.TunnelType() == types.TCP { + tunnelType = "TCP" + } else { + tunnelType = "UNKNOWN" + } + return &Detail{ + ForwardingType: tunnelType, + Slug: s.slug.String(), + UserID: s.lifecycle.User(), Active: s.lifecycle.IsActive(), StartedAt: s.lifecycle.StartedAt(), } } -func (s *SSHSession) Start() error { +func (s *session) Start() error { var channel ssh.NewChannel var ok bool select { - case channel, ok = <-s.sshReqChannel: + case channel, ok = <-s.sshChan: if !ok { log.Println("Forwarding request channel closed") return nil @@ -122,7 +136,7 @@ func (s *SSHSession) Start() error { return fmt.Errorf("no forwarding Request") } - if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.GetUser() == "UNAUTHORIZED" { + if (s.interaction.Mode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.User() == "UNAUTHORIZED" { if err := tcpipReq.Reply(false, nil); err != nil { log.Printf("cannot reply to tcpip req: %s\n", err) return err @@ -137,7 +151,7 @@ func (s *SSHSession) Start() error { s.HandleTCPIPForward(tcpipReq) s.interaction.Start() - s.lifecycle.GetConnection().Wait() + s.lifecycle.Connection().Wait() if err := s.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) return err @@ -145,7 +159,7 @@ func (s *SSHSession) Start() error { return nil } -func (s *SSHSession) waitForTCPIPForward() *ssh.Request { +func (s *session) waitForTCPIPForward() *ssh.Request { select { case req, ok := <-s.initialReq: if !ok { diff --git a/session/slug/slug.go b/session/slug/slug.go index 7ab4697..b9684d1 100644 --- a/session/slug/slug.go +++ b/session/slug/slug.go @@ -1,24 +1,24 @@ package slug -type Manager interface { - Get() string +type Slug interface { + String() string Set(slug string) } -type manager struct { +type slug struct { slug string } -func NewManager() Manager { - return &manager{ +func New() Slug { + return &slug{ slug: "", } } -func (s *manager) Get() string { +func (s *slug) String() string { return s.slug } -func (s *manager) Set(slug string) { +func (s *slug) Set(slug string) { s.slug = slug } diff --git a/types/types.go b/types/types.go index bb8d199..148cd2b 100644 --- a/types/types.go +++ b/types/types.go @@ -1,26 +1,25 @@ package types -type Status string +type Status int const ( - INITIALIZING Status = "INITIALIZING" - RUNNING Status = "RUNNING" - SETUP Status = "SETUP" + INITIALIZING Status = iota + RUNNING ) -type Mode string +type Mode int const ( - INTERACTIVE Mode = "INTERACTIVE" - HEADLESS Mode = "HEADLESS" + INTERACTIVE Mode = iota + HEADLESS ) -type TunnelType string +type TunnelType int const ( - UNKNOWN TunnelType = "UNKNOWN" - HTTP TunnelType = "HTTP" - TCP TunnelType = "TCP" + UNKNOWN TunnelType = iota + HTTP + TCP ) type SessionKey struct {