refactor: replace Get/Set patterns with idiomatic Go interfaces
Docker Build and Push / build-and-push-tags (push) Successful in 10m59s
Docker Build and Push / build-and-push-branches (push) Has been skipped

- rename constructors to New
- remove Get/Set-style accessors
- replace string-based enums with iota-backed types
This commit is contained in:
2026-01-14 15:28:17 +07:00
parent abd103b5ab
commit 07d9f3afe6
10 changed files with 231 additions and 214 deletions
+2 -2
View File
@@ -263,7 +263,7 @@ func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
}, "slug change failure response")
}
userSession.GetInteraction().Redraw()
userSession.Interaction().Redraw()
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
@@ -321,7 +321,7 @@ func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto
}, "terminate session fetch failed")
}
if err = userSession.GetLifecycle().Close(); err != nil {
if err = userSession.Lifecycle().Close(); err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
+6 -6
View File
@@ -335,8 +335,8 @@ func (hs *httpServer) handler(conn net.Conn) {
return
}
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) {
payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct {
channel ssh.Channel
@@ -346,7 +346,7 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload)
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
@@ -357,14 +357,14 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
channel = result.channel
reqs = result.reqs
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
@@ -390,6 +390,6 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
return
}
sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return
}
+30 -30
View File
@@ -30,50 +30,50 @@ func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
return io.CopyBuffer(dst, src, buf)
}
type Forwarder struct {
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slugManager slug.Manager
slug slug.Slug
lifecycle Lifecycle
}
func NewForwarder(slugManager slug.Manager) *Forwarder {
return &Forwarder{
func New(slug slug.Slug) Forwarder {
return &forwarder{
listener: nil,
tunnelType: "",
tunnelType: types.UNKNOWN,
forwardedPort: 0,
slugManager: slugManager,
slug: slug,
lifecycle: nil,
}
}
type Lifecycle interface {
GetConnection() ssh.Conn
Connection() ssh.Conn
}
type ForwardingController interface {
AcceptTCPConnections()
type Forwarder interface {
SetType(tunnelType types.TunnelType)
GetTunnelType() types.TunnelType
GetForwardedPort() uint16
SetLifecycle(lifecycle Lifecycle)
SetForwardedPort(port uint16)
SetListener(listener net.Listener)
GetListener() net.Listener
Close() error
Listener() net.Listener
TunnelType() types.TunnelType
ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle)
CreateForwardedTCPIPPayload(origin net.Addr) []byte
WriteBadGatewayResponse(dst io.Writer)
AcceptTCPConnections()
Close() error
}
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
func (f *forwarder) SetLifecycle(lifecycle Lifecycle) {
f.lifecycle = lifecycle
}
func (f *Forwarder) AcceptTCPConnections() {
func (f *forwarder) AcceptTCPConnections() {
for {
conn, err := f.GetListener().Accept()
conn, err := f.Listener().Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
@@ -100,7 +100,7 @@ func (f *Forwarder) AcceptTCPConnections() {
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
channel, reqs, err := f.lifecycle.Connection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
@@ -130,7 +130,7 @@ func (f *Forwarder) AcceptTCPConnections() {
}
}
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
defer func() {
_, err := io.Copy(io.Discard, src)
if err != nil {
@@ -174,31 +174,31 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
wg.Wait()
}
func (f *Forwarder) SetType(tunnelType types.TunnelType) {
func (f *forwarder) SetType(tunnelType types.TunnelType) {
f.tunnelType = tunnelType
}
func (f *Forwarder) GetTunnelType() types.TunnelType {
func (f *forwarder) TunnelType() types.TunnelType {
return f.tunnelType
}
func (f *Forwarder) GetForwardedPort() uint16 {
func (f *forwarder) ForwardedPort() uint16 {
return f.forwardedPort
}
func (f *Forwarder) SetForwardedPort(port uint16) {
func (f *forwarder) SetForwardedPort(port uint16) {
f.forwardedPort = port
}
func (f *Forwarder) SetListener(listener net.Listener) {
func (f *forwarder) SetListener(listener net.Listener) {
f.listener = listener
}
func (f *Forwarder) GetListener() net.Listener {
func (f *forwarder) Listener() net.Listener {
return f.listener
}
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
_, err := dst.Write(types.BadGatewayResponse)
if err != nil {
log.Printf("failed to write Bad Gateway response: %v", err)
@@ -206,20 +206,20 @@ func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
}
}
func (f *Forwarder) Close() error {
if f.GetListener() != nil {
func (f *forwarder) Close() error {
if f.Listener() != nil {
return f.listener.Close()
}
return nil
}
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
var buf bytes.Buffer
host, originPort := parseAddr(origin.String())
writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort()))
err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
+6 -6
View File
@@ -15,7 +15,7 @@ import (
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest {
switch req.Type {
case "shell", "pty-req":
@@ -56,7 +56,7 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
}
}
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
func (s *session) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected")
fail := func(msg string) {
@@ -103,7 +103,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
}
}
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
fail := func(msg string, key *types.SessionKey) {
log.Println(msg)
if key != nil {
@@ -137,11 +137,11 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
s.forwarder.SetType(types.HTTP)
s.forwarder.SetForwardedPort(portToBind)
s.slugManager.Set(slug)
s.slug.Set(slug)
s.lifecycle.SetStatus(types.RUNNING)
}
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
@@ -219,7 +219,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
s.forwarder.SetType(types.TCP)
s.forwarder.SetListener(listener)
s.forwarder.SetForwardedPort(portToBind)
s.slugManager.Set(key.Id)
s.slug.Set(key.Id)
s.lifecycle.SetStatus(types.RUNNING)
go s.forwarder.AcceptTCPConnections()
}
+31 -31
View File
@@ -23,34 +23,34 @@ import (
type Lifecycle interface {
Close() error
GetUser() string
User() string
}
type SessionRegistry interface {
Update(user string, oldKey, newKey types.SessionKey) error
}
type Controller interface {
type Interaction interface {
Mode() types.Mode
SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle)
Start()
SetWH(w, h int)
Redraw()
SetSessionRegistry(registry SessionRegistry)
SetMode(m types.Mode)
GetMode() types.Mode
SetWH(w, h int)
Start()
Redraw()
Send(message string) error
}
type Forwarder interface {
Close() error
GetTunnelType() types.TunnelType
GetForwardedPort() uint16
TunnelType() types.TunnelType
ForwardedPort() uint16
}
type Interaction struct {
type interaction struct {
channel ssh.Channel
slugManager slug.Manager
slug slug.Slug
forwarder Forwarder
lifecycle Lifecycle
sessionRegistry SessionRegistry
@@ -60,22 +60,22 @@ type Interaction struct {
mode types.Mode
}
func (i *Interaction) SetMode(m types.Mode) {
func (i *interaction) SetMode(m types.Mode) {
i.mode = m
}
func (i *Interaction) GetMode() types.Mode {
func (i *interaction) Mode() types.Mode {
return i.mode
}
func (i *Interaction) Send(message string) error {
func (i *interaction) Send(message string) error {
if i.channel != nil {
_, err := i.channel.Write([]byte(message))
return err
}
return nil
}
func (i *Interaction) SetWH(w, h int) {
func (i *interaction) SetWH(w, h int) {
if i.program != nil {
i.program.Send(tea.WindowSizeMsg{
Width: w,
@@ -103,14 +103,14 @@ type model struct {
commandList list.Model
slugInput textinput.Model
slugError string
interaction *Interaction
interaction *interaction
width int
height int
}
func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP {
return buildURL(m.protocol, m.interaction.slugManager.Get(), m.domain)
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
}
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
}
@@ -123,11 +123,11 @@ type keymap struct {
type tickMsg time.Time
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
func New(slug slug.Slug, forwarder Forwarder) Interaction {
ctx, cancel := context.WithCancel(context.Background())
return &Interaction{
return &interaction{
channel: nil,
slugManager: slugManager,
slug: slug,
forwarder: forwarder,
lifecycle: nil,
sessionRegistry: nil,
@@ -137,19 +137,19 @@ func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction
}
}
func (i *Interaction) SetSessionRegistry(registry SessionRegistry) {
func (i *interaction) SetSessionRegistry(registry SessionRegistry) {
i.sessionRegistry = registry
}
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle
}
func (i *Interaction) SetChannel(channel ssh.Channel) {
func (i *interaction) SetChannel(channel ssh.Channel) {
i.channel = channel
}
func (i *Interaction) Stop() {
func (i *interaction) Stop() {
if i.cancel != nil {
i.cancel()
}
@@ -242,8 +242,8 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter":
inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.GetUser(), types.SessionKey{
Id: m.interaction.slugManager.Get(),
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.User(), types.SessionKey{
Id: m.interaction.slug.String(),
Type: types.HTTP,
}, types.SessionKey{
Id: inputValue,
@@ -285,7 +285,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if item.name == "slug" {
m.showingCommands = false
m.editingSlug = true
m.slugInput.SetValue(m.interaction.slugManager.Get())
m.slugInput.SetValue(m.interaction.slug.String())
m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" {
@@ -317,7 +317,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil
}
func (i *Interaction) Redraw() {
func (i *interaction) Redraw() {
if i.program != nil {
i.program.Send(tea.ClearScreen())
}
@@ -691,7 +691,7 @@ func (m *model) View() string {
MarginBottom(boxMargin).
Width(boxMaxWidth)
authenticatedUser := m.interaction.lifecycle.GetUser()
authenticatedUser := m.interaction.lifecycle.User()
userInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
@@ -767,7 +767,7 @@ func (m *model) View() string {
return b.String()
}
func (i *Interaction) Start() {
func (i *interaction) Start() {
if i.mode == types.HEADLESS {
return
}
@@ -779,8 +779,8 @@ func (i *Interaction) Start() {
protocol = "https"
}
tunnelType := i.forwarder.GetTunnelType()
port := i.forwarder.GetForwardedPort()
tunnelType := i.forwarder.TunnelType()
port := i.forwarder.ForwardedPort()
items := []list.Item{
commandItem{name: "slug", desc: "Set custom subdomain"},
+44 -40
View File
@@ -15,115 +15,119 @@ import (
type Forwarder interface {
Close() error
GetTunnelType() types.TunnelType
GetForwardedPort() uint16
TunnelType() types.TunnelType
ForwardedPort() uint16
}
type SessionRegistry interface {
Remove(key types.SessionKey)
}
type Lifecycle struct {
type lifecycle struct {
status types.Status
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
sessionRegistry SessionRegistry
slugManager slug.Manager
slug slug.Slug
startedAt time.Time
user string
}
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager, user string) *Lifecycle {
return &Lifecycle{
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle {
return &lifecycle{
status: types.INITIALIZING,
conn: conn,
channel: nil,
forwarder: forwarder,
slugManager: slugManager,
slug: slugManager,
sessionRegistry: nil,
startedAt: time.Now(),
user: user,
}
}
func (l *Lifecycle) SetSessionRegistry(registry SessionRegistry) {
func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) {
l.sessionRegistry = registry
}
type SessionLifecycle interface {
Close() error
SetStatus(status types.Status)
GetConnection() ssh.Conn
GetChannel() ssh.Channel
GetUser() string
type Lifecycle interface {
Connection() ssh.Conn
Channel() ssh.Channel
User() string
SetChannel(channel ssh.Channel)
SetSessionRegistry(registry SessionRegistry)
SetStatus(status types.Status)
IsActive() bool
StartedAt() time.Time
Close() error
}
func (l *Lifecycle) GetUser() string {
func (l *lifecycle) User() string {
return l.user
}
func (l *Lifecycle) GetChannel() ssh.Channel {
func (l *lifecycle) Channel() ssh.Channel {
return l.channel
}
func (l *Lifecycle) SetChannel(channel ssh.Channel) {
func (l *lifecycle) SetChannel(channel ssh.Channel) {
l.channel = channel
}
func (l *Lifecycle) GetConnection() ssh.Conn {
func (l *lifecycle) Connection() ssh.Conn {
return l.conn
}
func (l *Lifecycle) SetStatus(status types.Status) {
func (l *lifecycle) SetStatus(status types.Status) {
l.status = status
if status == types.RUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now()
}
}
func (l *Lifecycle) Close() error {
err := l.forwarder.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
func (l *lifecycle) Close() error {
var firstErr error
tunnelType := l.forwarder.TunnelType()
if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
firstErr = err
}
if l.channel != nil {
err := l.channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
return err
if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
if firstErr == nil {
firstErr = err
}
}
}
if l.conn != nil {
err := l.conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if firstErr == nil {
firstErr = err
}
}
}
clientSlug := l.slugManager.Get()
if clientSlug != "" && l.sessionRegistry.Remove != nil {
key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()}
clientSlug := l.slug.String()
key := types.SessionKey{
Id: clientSlug,
Type: tunnelType,
}
l.sessionRegistry.Remove(key)
}
if l.forwarder.GetTunnelType() == types.TCP {
err = portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
if err != nil {
return err
if tunnelType == types.TCP {
if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
firstErr = err
}
}
return nil
return firstErr
}
func (l *Lifecycle) IsActive() bool {
func (l *lifecycle) IsActive() bool {
return l.status == types.RUNNING
}
func (l *Lifecycle) StartedAt() time.Time {
func (l *lifecycle) StartedAt() time.Time {
return l.startedAt
}
+16 -16
View File
@@ -9,27 +9,27 @@ import (
type Key = types.SessionKey
type Registry interface {
Get(key Key) (session *SSHSession, err error)
GetWithUser(user string, key Key) (session *SSHSession, err error)
Get(key Key) (session Session, err error)
GetWithUser(user string, key Key) (session Session, err error)
Update(user string, oldKey, newKey Key) error
Register(key Key, session *SSHSession) (success bool)
Register(key Key, session Session) (success bool)
Remove(key Key)
GetAllSessionFromUser(user string) []*SSHSession
GetAllSessionFromUser(user string) []Session
}
type registry struct {
mu sync.RWMutex
byUser map[string]map[Key]*SSHSession
byUser map[string]map[Key]Session
slugIndex map[Key]string
}
func NewRegistry() Registry {
return &registry{
byUser: make(map[string]map[Key]*SSHSession),
byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string),
}
}
func (r *registry) Get(key Key) (session *SSHSession, err error) {
func (r *registry) Get(key Key) (session Session, err error) {
r.mu.RLock()
defer r.mu.RUnlock()
@@ -45,7 +45,7 @@ func (r *registry) Get(key Key) (session *SSHSession, err error) {
return client, nil
}
func (r *registry) GetWithUser(user string, key Key) (session *SSHSession, err error) {
func (r *registry) GetWithUser(user string, key Key) (session Session, err error) {
r.mu.RLock()
defer r.mu.RUnlock()
@@ -87,17 +87,17 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
delete(r.byUser[user], oldKey)
delete(r.slugIndex, oldKey)
client.slugManager.Set(newKey.Id)
client.Slug().Set(newKey.Id)
r.slugIndex[newKey] = user
if r.byUser[user] == nil {
r.byUser[user] = make(map[Key]*SSHSession)
r.byUser[user] = make(map[Key]Session)
}
r.byUser[user][newKey] = client
return nil
}
func (r *registry) Register(key Key, session *SSHSession) (success bool) {
func (r *registry) Register(key Key, session Session) (success bool) {
r.mu.Lock()
defer r.mu.Unlock()
@@ -105,9 +105,9 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) {
return false
}
userID := session.lifecycle.GetUser()
userID := session.Lifecycle().User()
if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]*SSHSession)
r.byUser[userID] = make(map[Key]Session)
}
r.byUser[userID][key] = session
@@ -115,16 +115,16 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) {
return true
}
func (r *registry) GetAllSessionFromUser(user string) []*SSHSession {
func (r *registry) GetAllSessionFromUser(user string) []Session {
r.mu.RLock()
defer r.mu.RUnlock()
m := r.byUser[user]
if len(m) == 0 {
return []*SSHSession{}
return []Session{}
}
sessions := make([]*SSHSession, 0, len(m))
sessions := make([]Session, 0, len(m))
for _, s := range m {
sessions = append(sessions, s)
}
+79 -65
View File
@@ -14,61 +14,6 @@ import (
"golang.org/x/crypto/ssh"
)
type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16)
HandleTCPForward(req *ssh.Request, addr string, port uint16)
}
type SSHSession struct {
initialReq <-chan *ssh.Request
sshReqChannel <-chan ssh.NewChannel
lifecycle lifecycle.SessionLifecycle
interaction interaction.Controller
forwarder forwarder.ForwardingController
slugManager slug.Manager
registry Registry
}
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
return s.lifecycle
}
func (s *SSHSession) GetInteraction() interaction.Controller {
return s.interaction
}
func (s *SSHSession) GetForwarder() forwarder.ForwardingController {
return s.forwarder
}
func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager
}
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) *SSHSession {
slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager)
forwarderManager.SetLifecycle(lifecycleManager)
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
return &SSHSession{
initialReq: forwardingReq,
sshReqChannel: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slugManager: slugManager,
registry: sessionRegistry,
}
}
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
@@ -77,21 +22,90 @@ type Detail struct {
StartedAt time.Time `json:"started_at,omitempty"`
}
func (s *SSHSession) Detail() Detail {
return Detail{
ForwardingType: string(s.forwarder.GetTunnelType()),
Slug: s.slugManager.Get(),
UserID: s.lifecycle.GetUser(),
type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16)
HandleTCPForward(req *ssh.Request, addr string, port uint16)
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *Detail
Start() error
}
type session struct {
initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel
lifecycle lifecycle.Lifecycle
interaction interaction.Interaction
forwarder forwarder.Forwarder
slug slug.Slug
registry Registry
}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session {
slugManager := slug.New()
forwarderManager := forwarder.New(slugManager)
interactionManager := interaction.New(slugManager, forwarderManager)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager)
forwarderManager.SetLifecycle(lifecycleManager)
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
return &session{
initialReq: initialReq,
sshChan: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slug: slugManager,
registry: sessionRegistry,
}
}
func (s *session) Lifecycle() lifecycle.Lifecycle {
return s.lifecycle
}
func (s *session) Interaction() interaction.Interaction {
return s.interaction
}
func (s *session) Forwarder() forwarder.Forwarder {
return s.forwarder
}
func (s *session) Slug() slug.Slug {
return s.slug
}
func (s *session) Detail() *Detail {
var tunnelType string
if s.forwarder.TunnelType() == types.HTTP {
tunnelType = "HTTP"
} else if s.forwarder.TunnelType() == types.TCP {
tunnelType = "TCP"
} else {
tunnelType = "UNKNOWN"
}
return &Detail{
ForwardingType: tunnelType,
Slug: s.slug.String(),
UserID: s.lifecycle.User(),
Active: s.lifecycle.IsActive(),
StartedAt: s.lifecycle.StartedAt(),
}
}
func (s *SSHSession) Start() error {
func (s *session) Start() error {
var channel ssh.NewChannel
var ok bool
select {
case channel, ok = <-s.sshReqChannel:
case channel, ok = <-s.sshChan:
if !ok {
log.Println("Forwarding request channel closed")
return nil
@@ -122,7 +136,7 @@ func (s *SSHSession) Start() error {
return fmt.Errorf("no forwarding Request")
}
if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.GetUser() == "UNAUTHORIZED" {
if (s.interaction.Mode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.User() == "UNAUTHORIZED" {
if err := tcpipReq.Reply(false, nil); err != nil {
log.Printf("cannot reply to tcpip req: %s\n", err)
return err
@@ -137,7 +151,7 @@ func (s *SSHSession) Start() error {
s.HandleTCPIPForward(tcpipReq)
s.interaction.Start()
s.lifecycle.GetConnection().Wait()
s.lifecycle.Connection().Wait()
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err
@@ -145,7 +159,7 @@ func (s *SSHSession) Start() error {
return nil
}
func (s *SSHSession) waitForTCPIPForward() *ssh.Request {
func (s *session) waitForTCPIPForward() *ssh.Request {
select {
case req, ok := <-s.initialReq:
if !ok {
+7 -7
View File
@@ -1,24 +1,24 @@
package slug
type Manager interface {
Get() string
type Slug interface {
String() string
Set(slug string)
}
type manager struct {
type slug struct {
slug string
}
func NewManager() Manager {
return &manager{
func New() Slug {
return &slug{
slug: "",
}
}
func (s *manager) Get() string {
func (s *slug) String() string {
return s.slug
}
func (s *manager) Set(slug string) {
func (s *slug) Set(slug string) {
s.slug = slug
}
+10 -11
View File
@@ -1,26 +1,25 @@
package types
type Status string
type Status int
const (
INITIALIZING Status = "INITIALIZING"
RUNNING Status = "RUNNING"
SETUP Status = "SETUP"
INITIALIZING Status = iota
RUNNING
)
type Mode string
type Mode int
const (
INTERACTIVE Mode = "INTERACTIVE"
HEADLESS Mode = "HEADLESS"
INTERACTIVE Mode = iota
HEADLESS
)
type TunnelType string
type TunnelType int
const (
UNKNOWN TunnelType = "UNKNOWN"
HTTP TunnelType = "HTTP"
TCP TunnelType = "TCP"
UNKNOWN TunnelType = iota
HTTP
TCP
)
type SessionKey struct {