refactor: replace Get/Set patterns with idiomatic Go interfaces
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 13m4s

- rename constructors to New
- remove Get/Set-style accessors
- replace string-based enums with iota-backed types
This commit is contained in:
2026-01-14 15:28:17 +07:00
parent ae3ed52d16
commit dbdf8094fa
10 changed files with 231 additions and 214 deletions
+2 -2
View File
@@ -263,7 +263,7 @@ func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
}, "slug change failure response") }, "slug change failure response")
} }
userSession.GetInteraction().Redraw() userSession.Interaction().Redraw()
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{ Payload: &proto.Node_SlugEventResponse{
@@ -321,7 +321,7 @@ func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto
}, "terminate session fetch failed") }, "terminate session fetch failed")
} }
if err = userSession.GetLifecycle().Close(); err != nil { if err = userSession.Lifecycle().Close(); err != nil {
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION, Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{ Payload: &proto.Node_TerminateSessionEventResponse{
+6 -6
View File
@@ -335,8 +335,8 @@ func (hs *httpServer) handler(conn net.Conn) {
return return
} }
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) { func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) {
payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr()) payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct { type channelResult struct {
channel ssh.Channel channel ssh.Channel
@@ -346,7 +346,7 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -357,14 +357,14 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
case result := <-resultChan: case result := <-resultChan:
if result.err != nil { if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter()) sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
channel = result.channel channel = result.channel
reqs = result.reqs reqs = result.reqs
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel") log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter()) sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
@@ -390,6 +390,6 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
return return
} }
sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr()) sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return return
} }
+30 -30
View File
@@ -30,50 +30,50 @@ func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
return io.CopyBuffer(dst, src, buf) return io.CopyBuffer(dst, src, buf)
} }
type Forwarder struct { type forwarder struct {
listener net.Listener listener net.Listener
tunnelType types.TunnelType tunnelType types.TunnelType
forwardedPort uint16 forwardedPort uint16
slugManager slug.Manager slug slug.Slug
lifecycle Lifecycle lifecycle Lifecycle
} }
func NewForwarder(slugManager slug.Manager) *Forwarder { func New(slug slug.Slug) Forwarder {
return &Forwarder{ return &forwarder{
listener: nil, listener: nil,
tunnelType: "", tunnelType: types.UNKNOWN,
forwardedPort: 0, forwardedPort: 0,
slugManager: slugManager, slug: slug,
lifecycle: nil, lifecycle: nil,
} }
} }
type Lifecycle interface { type Lifecycle interface {
GetConnection() ssh.Conn Connection() ssh.Conn
} }
type ForwardingController interface { type Forwarder interface {
AcceptTCPConnections()
SetType(tunnelType types.TunnelType) SetType(tunnelType types.TunnelType)
GetTunnelType() types.TunnelType SetLifecycle(lifecycle Lifecycle)
GetForwardedPort() uint16
SetForwardedPort(port uint16) SetForwardedPort(port uint16)
SetListener(listener net.Listener) SetListener(listener net.Listener)
GetListener() net.Listener Listener() net.Listener
Close() error TunnelType() types.TunnelType
ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle)
CreateForwardedTCPIPPayload(origin net.Addr) []byte CreateForwardedTCPIPPayload(origin net.Addr) []byte
WriteBadGatewayResponse(dst io.Writer) WriteBadGatewayResponse(dst io.Writer)
AcceptTCPConnections()
Close() error
} }
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { func (f *forwarder) SetLifecycle(lifecycle Lifecycle) {
f.lifecycle = lifecycle f.lifecycle = lifecycle
} }
func (f *Forwarder) AcceptTCPConnections() { func (f *forwarder) AcceptTCPConnections() {
for { for {
conn, err := f.GetListener().Accept() conn, err := f.Listener().Accept()
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return return
@@ -100,7 +100,7 @@ func (f *Forwarder) AcceptTCPConnections() {
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := f.lifecycle.Connection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -130,7 +130,7 @@ func (f *Forwarder) AcceptTCPConnections() {
} }
} }
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
defer func() { defer func() {
_, err := io.Copy(io.Discard, src) _, err := io.Copy(io.Discard, src)
if err != nil { if err != nil {
@@ -174,31 +174,31 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
wg.Wait() wg.Wait()
} }
func (f *Forwarder) SetType(tunnelType types.TunnelType) { func (f *forwarder) SetType(tunnelType types.TunnelType) {
f.tunnelType = tunnelType f.tunnelType = tunnelType
} }
func (f *Forwarder) GetTunnelType() types.TunnelType { func (f *forwarder) TunnelType() types.TunnelType {
return f.tunnelType return f.tunnelType
} }
func (f *Forwarder) GetForwardedPort() uint16 { func (f *forwarder) ForwardedPort() uint16 {
return f.forwardedPort return f.forwardedPort
} }
func (f *Forwarder) SetForwardedPort(port uint16) { func (f *forwarder) SetForwardedPort(port uint16) {
f.forwardedPort = port f.forwardedPort = port
} }
func (f *Forwarder) SetListener(listener net.Listener) { func (f *forwarder) SetListener(listener net.Listener) {
f.listener = listener f.listener = listener
} }
func (f *Forwarder) GetListener() net.Listener { func (f *forwarder) Listener() net.Listener {
return f.listener return f.listener
} }
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
_, err := dst.Write(types.BadGatewayResponse) _, err := dst.Write(types.BadGatewayResponse)
if err != nil { if err != nil {
log.Printf("failed to write Bad Gateway response: %v", err) log.Printf("failed to write Bad Gateway response: %v", err)
@@ -206,20 +206,20 @@ func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
} }
} }
func (f *Forwarder) Close() error { func (f *forwarder) Close() error {
if f.GetListener() != nil { if f.Listener() != nil {
return f.listener.Close() return f.listener.Close()
} }
return nil return nil
} }
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
var buf bytes.Buffer var buf bytes.Buffer
host, originPort := parseAddr(origin.String()) host, originPort := parseAddr(origin.String())
writeSSHString(&buf, "localhost") writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort())) err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
if err != nil { if err != nil {
log.Printf("Failed to write string to buffer: %v", err) log.Printf("Failed to write string to buffer: %v", err)
return nil return nil
+6 -6
View File
@@ -15,7 +15,7 @@ import (
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest { for req := range GlobalRequest {
switch req.Type { switch req.Type {
case "shell", "pty-req": case "shell", "pty-req":
@@ -56,7 +56,7 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
} }
} }
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { func (s *session) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected") log.Println("Port forwarding request detected")
fail := func(msg string) { fail := func(msg string) {
@@ -103,7 +103,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
} }
} }
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
fail := func(msg string, key *types.SessionKey) { fail := func(msg string, key *types.SessionKey) {
log.Println(msg) log.Println(msg)
if key != nil { if key != nil {
@@ -137,11 +137,11 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
s.forwarder.SetType(types.HTTP) s.forwarder.SetType(types.HTTP)
s.forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.slugManager.Set(slug) s.slug.Set(slug)
s.lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
} }
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
fail := func(msg string) { fail := func(msg string) {
log.Println(msg) log.Println(msg)
if err := req.Reply(false, nil); err != nil { if err := req.Reply(false, nil); err != nil {
@@ -219,7 +219,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
s.forwarder.SetType(types.TCP) s.forwarder.SetType(types.TCP)
s.forwarder.SetListener(listener) s.forwarder.SetListener(listener)
s.forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.slugManager.Set(key.Id) s.slug.Set(key.Id)
s.lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
go s.forwarder.AcceptTCPConnections() go s.forwarder.AcceptTCPConnections()
} }
+31 -31
View File
@@ -23,34 +23,34 @@ import (
type Lifecycle interface { type Lifecycle interface {
Close() error Close() error
GetUser() string User() string
} }
type SessionRegistry interface { type SessionRegistry interface {
Update(user string, oldKey, newKey types.SessionKey) error Update(user string, oldKey, newKey types.SessionKey) error
} }
type Controller interface { type Interaction interface {
Mode() types.Mode
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
Start()
SetWH(w, h int)
Redraw()
SetSessionRegistry(registry SessionRegistry) SetSessionRegistry(registry SessionRegistry)
SetMode(m types.Mode) SetMode(m types.Mode)
GetMode() types.Mode SetWH(w, h int)
Start()
Redraw()
Send(message string) error Send(message string) error
} }
type Forwarder interface { type Forwarder interface {
Close() error Close() error
GetTunnelType() types.TunnelType TunnelType() types.TunnelType
GetForwardedPort() uint16 ForwardedPort() uint16
} }
type Interaction struct { type interaction struct {
channel ssh.Channel channel ssh.Channel
slugManager slug.Manager slug slug.Slug
forwarder Forwarder forwarder Forwarder
lifecycle Lifecycle lifecycle Lifecycle
sessionRegistry SessionRegistry sessionRegistry SessionRegistry
@@ -60,22 +60,22 @@ type Interaction struct {
mode types.Mode mode types.Mode
} }
func (i *Interaction) SetMode(m types.Mode) { func (i *interaction) SetMode(m types.Mode) {
i.mode = m i.mode = m
} }
func (i *Interaction) GetMode() types.Mode { func (i *interaction) Mode() types.Mode {
return i.mode return i.mode
} }
func (i *Interaction) Send(message string) error { func (i *interaction) Send(message string) error {
if i.channel != nil { if i.channel != nil {
_, err := i.channel.Write([]byte(message)) _, err := i.channel.Write([]byte(message))
return err return err
} }
return nil return nil
} }
func (i *Interaction) SetWH(w, h int) { func (i *interaction) SetWH(w, h int) {
if i.program != nil { if i.program != nil {
i.program.Send(tea.WindowSizeMsg{ i.program.Send(tea.WindowSizeMsg{
Width: w, Width: w,
@@ -103,14 +103,14 @@ type model struct {
commandList list.Model commandList list.Model
slugInput textinput.Model slugInput textinput.Model
slugError string slugError string
interaction *Interaction interaction *interaction
width int width int
height int height int
} }
func (m *model) getTunnelURL() string { func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP { if m.tunnelType == types.HTTP {
return buildURL(m.protocol, m.interaction.slugManager.Get(), m.domain) return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
} }
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port) return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
} }
@@ -123,11 +123,11 @@ type keymap struct {
type tickMsg time.Time type tickMsg time.Time
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction { func New(slug slug.Slug, forwarder Forwarder) Interaction {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Interaction{ return &interaction{
channel: nil, channel: nil,
slugManager: slugManager, slug: slug,
forwarder: forwarder, forwarder: forwarder,
lifecycle: nil, lifecycle: nil,
sessionRegistry: nil, sessionRegistry: nil,
@@ -137,19 +137,19 @@ func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction
} }
} }
func (i *Interaction) SetSessionRegistry(registry SessionRegistry) { func (i *interaction) SetSessionRegistry(registry SessionRegistry) {
i.sessionRegistry = registry i.sessionRegistry = registry
} }
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle i.lifecycle = lifecycle
} }
func (i *Interaction) SetChannel(channel ssh.Channel) { func (i *interaction) SetChannel(channel ssh.Channel) {
i.channel = channel i.channel = channel
} }
func (i *Interaction) Stop() { func (i *interaction) Stop() {
if i.cancel != nil { if i.cancel != nil {
i.cancel() i.cancel()
} }
@@ -242,8 +242,8 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter": case "enter":
inputValue := m.slugInput.Value() inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.GetUser(), types.SessionKey{ if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.User(), types.SessionKey{
Id: m.interaction.slugManager.Get(), Id: m.interaction.slug.String(),
Type: types.HTTP, Type: types.HTTP,
}, types.SessionKey{ }, types.SessionKey{
Id: inputValue, Id: inputValue,
@@ -285,7 +285,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if item.name == "slug" { if item.name == "slug" {
m.showingCommands = false m.showingCommands = false
m.editingSlug = true m.editingSlug = true
m.slugInput.SetValue(m.interaction.slugManager.Get()) m.slugInput.SetValue(m.interaction.slug.String())
m.slugInput.Focus() m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" { } else if item.name == "tunnel-type" {
@@ -317,7 +317,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil return m, nil
} }
func (i *Interaction) Redraw() { func (i *interaction) Redraw() {
if i.program != nil { if i.program != nil {
i.program.Send(tea.ClearScreen()) i.program.Send(tea.ClearScreen())
} }
@@ -691,7 +691,7 @@ func (m *model) View() string {
MarginBottom(boxMargin). MarginBottom(boxMargin).
Width(boxMaxWidth) Width(boxMaxWidth)
authenticatedUser := m.interaction.lifecycle.GetUser() authenticatedUser := m.interaction.lifecycle.User()
userInfoStyle := lipgloss.NewStyle(). userInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")). Foreground(lipgloss.Color("#FAFAFA")).
@@ -767,7 +767,7 @@ func (m *model) View() string {
return b.String() return b.String()
} }
func (i *Interaction) Start() { func (i *interaction) Start() {
if i.mode == types.HEADLESS { if i.mode == types.HEADLESS {
return return
} }
@@ -779,8 +779,8 @@ func (i *Interaction) Start() {
protocol = "https" protocol = "https"
} }
tunnelType := i.forwarder.GetTunnelType() tunnelType := i.forwarder.TunnelType()
port := i.forwarder.GetForwardedPort() port := i.forwarder.ForwardedPort()
items := []list.Item{ items := []list.Item{
commandItem{name: "slug", desc: "Set custom subdomain"}, commandItem{name: "slug", desc: "Set custom subdomain"},
+44 -40
View File
@@ -15,115 +15,119 @@ import (
type Forwarder interface { type Forwarder interface {
Close() error Close() error
GetTunnelType() types.TunnelType TunnelType() types.TunnelType
GetForwardedPort() uint16 ForwardedPort() uint16
} }
type SessionRegistry interface { type SessionRegistry interface {
Remove(key types.SessionKey) Remove(key types.SessionKey)
} }
type Lifecycle struct { type lifecycle struct {
status types.Status status types.Status
conn ssh.Conn conn ssh.Conn
channel ssh.Channel channel ssh.Channel
forwarder Forwarder forwarder Forwarder
sessionRegistry SessionRegistry sessionRegistry SessionRegistry
slugManager slug.Manager slug slug.Slug
startedAt time.Time startedAt time.Time
user string user string
} }
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager, user string) *Lifecycle { func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle {
return &Lifecycle{ return &lifecycle{
status: types.INITIALIZING, status: types.INITIALIZING,
conn: conn, conn: conn,
channel: nil, channel: nil,
forwarder: forwarder, forwarder: forwarder,
slugManager: slugManager, slug: slugManager,
sessionRegistry: nil, sessionRegistry: nil,
startedAt: time.Now(), startedAt: time.Now(),
user: user, user: user,
} }
} }
func (l *Lifecycle) SetSessionRegistry(registry SessionRegistry) { func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) {
l.sessionRegistry = registry l.sessionRegistry = registry
} }
type SessionLifecycle interface { type Lifecycle interface {
Close() error Connection() ssh.Conn
SetStatus(status types.Status) Channel() ssh.Channel
GetConnection() ssh.Conn User() string
GetChannel() ssh.Channel
GetUser() string
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetSessionRegistry(registry SessionRegistry) SetSessionRegistry(registry SessionRegistry)
SetStatus(status types.Status)
IsActive() bool IsActive() bool
StartedAt() time.Time StartedAt() time.Time
Close() error
} }
func (l *Lifecycle) GetUser() string { func (l *lifecycle) User() string {
return l.user return l.user
} }
func (l *Lifecycle) GetChannel() ssh.Channel { func (l *lifecycle) Channel() ssh.Channel {
return l.channel return l.channel
} }
func (l *Lifecycle) SetChannel(channel ssh.Channel) { func (l *lifecycle) SetChannel(channel ssh.Channel) {
l.channel = channel l.channel = channel
} }
func (l *Lifecycle) GetConnection() ssh.Conn { func (l *lifecycle) Connection() ssh.Conn {
return l.conn return l.conn
} }
func (l *Lifecycle) SetStatus(status types.Status) { func (l *lifecycle) SetStatus(status types.Status) {
l.status = status l.status = status
if status == types.RUNNING && l.startedAt.IsZero() { if status == types.RUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now() l.startedAt = time.Now()
} }
} }
func (l *Lifecycle) Close() error { func (l *lifecycle) Close() error {
err := l.forwarder.Close() var firstErr error
if err != nil && !errors.Is(err, net.ErrClosed) { tunnelType := l.forwarder.TunnelType()
return err
if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
firstErr = err
} }
if l.channel != nil { if l.channel != nil {
err := l.channel.Close() if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
if err != nil && !errors.Is(err, io.EOF) { if firstErr == nil {
return err firstErr = err
}
} }
} }
if l.conn != nil { if l.conn != nil {
err := l.conn.Close() if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if err != nil && !errors.Is(err, net.ErrClosed) { if firstErr == nil {
return err firstErr = err
}
} }
} }
clientSlug := l.slugManager.Get() clientSlug := l.slug.String()
if clientSlug != "" && l.sessionRegistry.Remove != nil { key := types.SessionKey{
key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()} Id: clientSlug,
Type: tunnelType,
}
l.sessionRegistry.Remove(key) l.sessionRegistry.Remove(key)
}
if l.forwarder.GetTunnelType() == types.TCP { if tunnelType == types.TCP {
err = portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false) if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
if err != nil { firstErr = err
return err
} }
} }
return nil return firstErr
} }
func (l *Lifecycle) IsActive() bool { func (l *lifecycle) IsActive() bool {
return l.status == types.RUNNING return l.status == types.RUNNING
} }
func (l *Lifecycle) StartedAt() time.Time { func (l *lifecycle) StartedAt() time.Time {
return l.startedAt return l.startedAt
} }
+16 -16
View File
@@ -9,27 +9,27 @@ import (
type Key = types.SessionKey type Key = types.SessionKey
type Registry interface { type Registry interface {
Get(key Key) (session *SSHSession, err error) Get(key Key) (session Session, err error)
GetWithUser(user string, key Key) (session *SSHSession, err error) GetWithUser(user string, key Key) (session Session, err error)
Update(user string, oldKey, newKey Key) error Update(user string, oldKey, newKey Key) error
Register(key Key, session *SSHSession) (success bool) Register(key Key, session Session) (success bool)
Remove(key Key) Remove(key Key)
GetAllSessionFromUser(user string) []*SSHSession GetAllSessionFromUser(user string) []Session
} }
type registry struct { type registry struct {
mu sync.RWMutex mu sync.RWMutex
byUser map[string]map[Key]*SSHSession byUser map[string]map[Key]Session
slugIndex map[Key]string slugIndex map[Key]string
} }
func NewRegistry() Registry { func NewRegistry() Registry {
return &registry{ return &registry{
byUser: make(map[string]map[Key]*SSHSession), byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string), slugIndex: make(map[Key]string),
} }
} }
func (r *registry) Get(key Key) (session *SSHSession, err error) { func (r *registry) Get(key Key) (session Session, err error) {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
@@ -45,7 +45,7 @@ func (r *registry) Get(key Key) (session *SSHSession, err error) {
return client, nil return client, nil
} }
func (r *registry) GetWithUser(user string, key Key) (session *SSHSession, err error) { func (r *registry) GetWithUser(user string, key Key) (session Session, err error) {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
@@ -87,17 +87,17 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
delete(r.byUser[user], oldKey) delete(r.byUser[user], oldKey)
delete(r.slugIndex, oldKey) delete(r.slugIndex, oldKey)
client.slugManager.Set(newKey.Id) client.Slug().Set(newKey.Id)
r.slugIndex[newKey] = user r.slugIndex[newKey] = user
if r.byUser[user] == nil { if r.byUser[user] == nil {
r.byUser[user] = make(map[Key]*SSHSession) r.byUser[user] = make(map[Key]Session)
} }
r.byUser[user][newKey] = client r.byUser[user][newKey] = client
return nil return nil
} }
func (r *registry) Register(key Key, session *SSHSession) (success bool) { func (r *registry) Register(key Key, session Session) (success bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
@@ -105,9 +105,9 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) {
return false return false
} }
userID := session.lifecycle.GetUser() userID := session.Lifecycle().User()
if r.byUser[userID] == nil { if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]*SSHSession) r.byUser[userID] = make(map[Key]Session)
} }
r.byUser[userID][key] = session r.byUser[userID][key] = session
@@ -115,16 +115,16 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) {
return true return true
} }
func (r *registry) GetAllSessionFromUser(user string) []*SSHSession { func (r *registry) GetAllSessionFromUser(user string) []Session {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
m := r.byUser[user] m := r.byUser[user]
if len(m) == 0 { if len(m) == 0 {
return []*SSHSession{} return []Session{}
} }
sessions := make([]*SSHSession, 0, len(m)) sessions := make([]Session, 0, len(m))
for _, s := range m { for _, s := range m {
sessions = append(sessions, s) sessions = append(sessions, s)
} }
+79 -65
View File
@@ -14,61 +14,6 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16)
HandleTCPForward(req *ssh.Request, addr string, port uint16)
}
type SSHSession struct {
initialReq <-chan *ssh.Request
sshReqChannel <-chan ssh.NewChannel
lifecycle lifecycle.SessionLifecycle
interaction interaction.Controller
forwarder forwarder.ForwardingController
slugManager slug.Manager
registry Registry
}
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
return s.lifecycle
}
func (s *SSHSession) GetInteraction() interaction.Controller {
return s.interaction
}
func (s *SSHSession) GetForwarder() forwarder.ForwardingController {
return s.forwarder
}
func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager
}
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) *SSHSession {
slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager)
forwarderManager.SetLifecycle(lifecycleManager)
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
return &SSHSession{
initialReq: forwardingReq,
sshReqChannel: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slugManager: slugManager,
registry: sessionRegistry,
}
}
type Detail struct { type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"` ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"` Slug string `json:"slug,omitempty"`
@@ -77,21 +22,90 @@ type Detail struct {
StartedAt time.Time `json:"started_at,omitempty"` StartedAt time.Time `json:"started_at,omitempty"`
} }
func (s *SSHSession) Detail() Detail { type Session interface {
return Detail{ HandleGlobalRequest(ch <-chan *ssh.Request)
ForwardingType: string(s.forwarder.GetTunnelType()), HandleTCPIPForward(req *ssh.Request)
Slug: s.slugManager.Get(), HandleHTTPForward(req *ssh.Request, port uint16)
UserID: s.lifecycle.GetUser(), HandleTCPForward(req *ssh.Request, addr string, port uint16)
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *Detail
Start() error
}
type session struct {
initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel
lifecycle lifecycle.Lifecycle
interaction interaction.Interaction
forwarder forwarder.Forwarder
slug slug.Slug
registry Registry
}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session {
slugManager := slug.New()
forwarderManager := forwarder.New(slugManager)
interactionManager := interaction.New(slugManager, forwarderManager)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager)
forwarderManager.SetLifecycle(lifecycleManager)
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
return &session{
initialReq: initialReq,
sshChan: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slug: slugManager,
registry: sessionRegistry,
}
}
func (s *session) Lifecycle() lifecycle.Lifecycle {
return s.lifecycle
}
func (s *session) Interaction() interaction.Interaction {
return s.interaction
}
func (s *session) Forwarder() forwarder.Forwarder {
return s.forwarder
}
func (s *session) Slug() slug.Slug {
return s.slug
}
func (s *session) Detail() *Detail {
var tunnelType string
if s.forwarder.TunnelType() == types.HTTP {
tunnelType = "HTTP"
} else if s.forwarder.TunnelType() == types.TCP {
tunnelType = "TCP"
} else {
tunnelType = "UNKNOWN"
}
return &Detail{
ForwardingType: tunnelType,
Slug: s.slug.String(),
UserID: s.lifecycle.User(),
Active: s.lifecycle.IsActive(), Active: s.lifecycle.IsActive(),
StartedAt: s.lifecycle.StartedAt(), StartedAt: s.lifecycle.StartedAt(),
} }
} }
func (s *SSHSession) Start() error { func (s *session) Start() error {
var channel ssh.NewChannel var channel ssh.NewChannel
var ok bool var ok bool
select { select {
case channel, ok = <-s.sshReqChannel: case channel, ok = <-s.sshChan:
if !ok { if !ok {
log.Println("Forwarding request channel closed") log.Println("Forwarding request channel closed")
return nil return nil
@@ -122,7 +136,7 @@ func (s *SSHSession) Start() error {
return fmt.Errorf("no forwarding Request") return fmt.Errorf("no forwarding Request")
} }
if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.GetUser() == "UNAUTHORIZED" { if (s.interaction.Mode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.User() == "UNAUTHORIZED" {
if err := tcpipReq.Reply(false, nil); err != nil { if err := tcpipReq.Reply(false, nil); err != nil {
log.Printf("cannot reply to tcpip req: %s\n", err) log.Printf("cannot reply to tcpip req: %s\n", err)
return err return err
@@ -137,7 +151,7 @@ func (s *SSHSession) Start() error {
s.HandleTCPIPForward(tcpipReq) s.HandleTCPIPForward(tcpipReq)
s.interaction.Start() s.interaction.Start()
s.lifecycle.GetConnection().Wait() s.lifecycle.Connection().Wait()
if err := s.lifecycle.Close(); err != nil { if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
return err return err
@@ -145,7 +159,7 @@ func (s *SSHSession) Start() error {
return nil return nil
} }
func (s *SSHSession) waitForTCPIPForward() *ssh.Request { func (s *session) waitForTCPIPForward() *ssh.Request {
select { select {
case req, ok := <-s.initialReq: case req, ok := <-s.initialReq:
if !ok { if !ok {
+7 -7
View File
@@ -1,24 +1,24 @@
package slug package slug
type Manager interface { type Slug interface {
Get() string String() string
Set(slug string) Set(slug string)
} }
type manager struct { type slug struct {
slug string slug string
} }
func NewManager() Manager { func New() Slug {
return &manager{ return &slug{
slug: "", slug: "",
} }
} }
func (s *manager) Get() string { func (s *slug) String() string {
return s.slug return s.slug
} }
func (s *manager) Set(slug string) { func (s *slug) Set(slug string) {
s.slug = slug s.slug = slug
} }
+10 -11
View File
@@ -1,26 +1,25 @@
package types package types
type Status string type Status int
const ( const (
INITIALIZING Status = "INITIALIZING" INITIALIZING Status = iota
RUNNING Status = "RUNNING" RUNNING
SETUP Status = "SETUP"
) )
type Mode string type Mode int
const ( const (
INTERACTIVE Mode = "INTERACTIVE" INTERACTIVE Mode = iota
HEADLESS Mode = "HEADLESS" HEADLESS
) )
type TunnelType string type TunnelType int
const ( const (
UNKNOWN TunnelType = "UNKNOWN" UNKNOWN TunnelType = iota
HTTP TunnelType = "HTTP" HTTP
TCP TunnelType = "TCP" TCP
) )
type SessionKey struct { type SessionKey struct {