diff --git a/server/http.go b/server/http.go index d2ad943..a69b836 100644 --- a/server/http.go +++ b/server/http.go @@ -30,13 +30,13 @@ type CustomWriter struct { buf []byte respHeader *ResponseHeaderFactory reqHeader *RequestHeaderFactory - interaction interaction.InteractionController + interaction interaction.Controller respMW []ResponseMiddleware reqStartMW []RequestMiddleware reqEndMW []RequestMiddleware } -func (cw *CustomWriter) SetInteraction(interaction interaction.InteractionController) { +func (cw *CustomWriter) SetInteraction(interaction interaction.Controller) { cw.interaction = interaction } @@ -350,7 +350,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS } } - sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) + sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) return } diff --git a/server/middleware.go b/server/middleware.go index d5f733b..ad8c546 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -29,11 +29,11 @@ func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body [ } type RequestLogger struct { - interaction interaction.InteractionController + interaction interaction.Controller remoteAddr net.Addr } -func NewRequestLogger(interaction interaction.InteractionController, remoteAddr net.Addr) *RequestLogger { +func NewRequestLogger(interaction interaction.Controller, remoteAddr net.Addr) *RequestLogger { return &RequestLogger{ interaction: interaction, remoteAddr: remoteAddr, diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 82794ba..41c9602 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -1,9 +1,17 @@ package forwarder import ( + "bytes" + "encoding/binary" + "errors" + "io" + "log" "net" + "strconv" "tunnel_pls/session/slug" "tunnel_pls/types" + + "golang.org/x/crypto/ssh" ) type Forwarder struct { @@ -11,14 +19,83 @@ type Forwarder struct { TunnelType types.TunnelType ForwardedPort uint16 SlugManager slug.Manager + Lifecycle Lifecycle +} + +type Lifecycle interface { + GetConnection() ssh.Conn +} + +type ForwardingController interface { + AcceptTCPConnections() + SetType(tunnelType types.TunnelType) + GetTunnelType() types.TunnelType + GetForwardedPort() uint16 + SetForwardedPort(port uint16) + SetListener(listener net.Listener) + GetListener() net.Listener + Close() error + HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) + SetLifecycle(lifecycle Lifecycle) +} + +func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { + f.Lifecycle = lifecycle } func (f *Forwarder) AcceptTCPConnections() { - panic("implement me") + for { + conn, err := f.GetListener().Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + log.Printf("Error accepting connection: %v", err) + continue + } + originHost, originPort := ParseAddr(conn.RemoteAddr().String()) + payload := createForwardedTCPIPPayload(originHost, uint16(originPort), f.GetForwardedPort()) + channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + return + } + + go func() { + for req := range reqs { + err := req.Reply(false, nil) + if err != nil { + log.Printf("Failed to reply to request: %v", err) + return + } + } + }() + go f.HandleConnection(conn, channel, conn.RemoteAddr()) + } } -func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { - panic("implement me") +func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { + defer func(src ssh.Channel) { + err := src.Close() + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing connection: %v", err) + } + }(src) + log.Printf("Handling new forwarded connection from %s", remoteAddr) + + go func() { + _, err := io.Copy(src, dst) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + log.Printf("Error copying from conn.Reader to channel: %v", err) + } + }() + + _, err := io.Copy(dst, src) + + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error copying from channel to conn.Writer: %v", err) + } + return } func (f *Forwarder) SetType(tunnelType types.TunnelType) { @@ -52,33 +129,39 @@ func (f *Forwarder) Close() error { return nil } -type ForwardingController interface { - AcceptTCPConnections() - UpdateClientSlug(oldSlug, newSlug string) bool - SetType(tunnelType types.TunnelType) - GetTunnelType() types.TunnelType - GetForwardedPort() uint16 - SetForwardedPort(port uint16) - SetListener(listener net.Listener) - GetListener() net.Listener - Close() error +func ParseAddr(addr string) (string, uint32) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint32(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint32(port) +} +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return + } + buffer.WriteString(str) } -//func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { -// session.clientsMutex.Lock() -// defer session.clientsMutex.Unlock() -// -// if _, exists := session.Clients[newSlug]; exists && newSlug != oldSlug { -// return false -// } -// -// client, ok := session.Clients[oldSlug] -// if !ok { -// return false -// } -// -// delete(session.Clients, oldSlug) -// f.SlugManager.Set(newSlug) -// session.Clients[newSlug] = client -// return true -//} +func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { + var buf bytes.Buffer + + writeSSHString(&buf, "localhost") + err := binary.Write(&buf, binary.BigEndian, uint32(port)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + writeSSHString(&buf, host) + err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + + return buf.Bytes() +} diff --git a/session/handler.go b/session/handler.go index c807d65..9123310 100644 --- a/session/handler.go +++ b/session/handler.go @@ -3,12 +3,9 @@ package session import ( "bytes" "encoding/binary" - "errors" "fmt" - "io" "log" "net" - "strconv" portUtil "tunnel_pls/internal/port" "tunnel_pls/types" @@ -17,10 +14,7 @@ import ( "golang.org/x/crypto/ssh" ) -type UserConnection struct { - Reader io.Reader - Writer net.Conn -} +var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { for req := range GlobalRequest { @@ -157,23 +151,6 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { s.HandleTCPForward(req, addr, portToBind) } -var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} - -func isBlockedPort(port uint16) bool { - if port == 80 || port == 443 { - return false - } - if port < 1024 && port != 0 { - return true - } - for _, p := range blockedReservedPorts { - if p == port { - return true - } - } - return false -} - func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { s.Forwarder.SetType(types.HTTP) s.Forwarder.SetForwardedPort(portToBind) @@ -237,7 +214,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind s.Interaction.ShowWelcomeMessage() s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) - go s.acceptTCPConnections() + go s.Forwarder.AcceptTCPConnections() buf := new(bytes.Buffer) err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) @@ -253,37 +230,6 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind } } -func (s *SSHSession) acceptTCPConnections() { - for { - conn, err := s.Forwarder.GetListener().Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) - continue - } - originHost, originPort := ParseAddr(conn.RemoteAddr().String()) - payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort()) - channel, reqs, err := s.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - return - } - - go func() { - for req := range reqs { - err := req.Reply(false, nil) - if err != nil { - log.Printf("Failed to reply to request: %v", err) - return - } - } - }() - go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr()) - } -} - func generateUniqueSlug() string { maxAttempts := 5 @@ -303,30 +249,6 @@ func generateUniqueSlug() string { return "" } -func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { - defer func(src ssh.Channel) { - err := src.Close() - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing connection: %v", err) - } - }(src) - log.Printf("Handling new forwarded connection from %s", remoteAddr) - - go func() { - _, err := io.Copy(src, dst) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying from conn.Reader to channel: %v", err) - } - }() - - _, err := io.Copy(dst, src) - - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error copying from channel to conn.Writer: %v", err) - } - return -} - func readSSHString(reader *bytes.Reader) (string, error) { var length uint32 if err := binary.Read(reader, binary.BigEndian, &length); err != nil { @@ -339,40 +261,17 @@ func readSSHString(reader *bytes.Reader) (string, error) { return string(strBytes), nil } -func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { - var buf bytes.Buffer - - writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(port)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil +func isBlockedPort(port uint16) bool { + if port == 80 || port == 443 { + return false } - writeSSHString(&buf, host) - err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil + if port < 1024 && port != 0 { + return true } - - return buf.Bytes() -} - -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return - } - buffer.WriteString(str) -} - -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) + for _, p := range blockedReservedPorts { + if p == port { + return true + } + } + return false } diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index d9e65d6..0f6c3ca 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -14,22 +14,27 @@ import ( "golang.org/x/crypto/ssh" ) +var forbiddenSlug = []string{ + "ping", +} + type Lifecycle interface { Close() error } -type InteractionController interface { +type Controller interface { SendMessage(message string) HandleUserInput() - HandleCommand(command string, commandBuffer *bytes.Buffer) - HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) + HandleCommand(command string) + HandleSlugEditMode(connection ssh.Channel, char byte) HandleSlugSave(conn ssh.Channel) - HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) + HandleSlugCancel(connection ssh.Channel) HandleSlugUpdateError() ShowWelcomeMessage() DisplaySlugEditor() SetChannel(channel ssh.Channel) SetLifecycle(lifecycle Lifecycle) + SetSlugModificator(func(oldSlug, newSlug string) bool) } type Forwarder interface { @@ -39,13 +44,14 @@ type Forwarder interface { } type Interaction struct { - CommandBuffer *bytes.Buffer - EditMode bool - EditSlug string - channel ssh.Channel - SlugManager slug.Manager - Forwarder Forwarder - Lifecycle Lifecycle + CommandBuffer *bytes.Buffer + EditMode bool + EditSlug string + channel ssh.Channel + SlugManager slug.Manager + Forwarder Forwarder + Lifecycle Lifecycle + updateClientSlug func(oldSlug, newSlug string) bool } func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { @@ -67,7 +73,6 @@ func (i *Interaction) SendMessage(message string) { } func (i *Interaction) HandleUserInput() { - var commandBuffer bytes.Buffer buf := make([]byte, 1) i.EditMode = false @@ -84,42 +89,42 @@ func (i *Interaction) HandleUserInput() { char := buf[0] if i.EditMode { - i.HandleSlugEditMode(i.channel, char, &commandBuffer) + i.HandleSlugEditMode(i.channel, char) continue } i.SendMessage(string(buf[:n])) if char == 8 || char == 127 { - if commandBuffer.Len() > 0 { - commandBuffer.Truncate(commandBuffer.Len() - 1) + if i.CommandBuffer.Len() > 0 { + i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) i.SendMessage("\b \b") } continue } if char == '/' { - commandBuffer.Reset() - commandBuffer.WriteByte(char) + i.CommandBuffer.Reset() + i.CommandBuffer.WriteByte(char) continue } - if commandBuffer.Len() > 0 { + if i.CommandBuffer.Len() > 0 { if char == 13 { - i.HandleCommand(commandBuffer.String(), &commandBuffer) + i.HandleCommand(i.CommandBuffer.String()) continue } - commandBuffer.WriteByte(char) + i.CommandBuffer.WriteByte(char) } } } } -func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) { +func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte) { if char == 13 { i.HandleSlugSave(connection) } else if char == 27 { - i.HandleSlugCancel(connection, commandBuffer) + i.HandleSlugCancel(connection) } else if char == 8 || char == 127 { if len(i.EditSlug) > 0 { i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1] @@ -160,13 +165,13 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { return } if isValid { - //oldSlug := i.SlugManager.Get() + oldSlug := i.SlugManager.Get() newSlug := i.EditSlug - //if !i.updateClientSlug(oldSlug, newSlug) { - // i.HandleSlugUpdateError() - // return - //} + if !i.updateClientSlug(oldSlug, newSlug) { + i.HandleSlugUpdateError() + return + } _, err := connection.Write([]byte("\r\n\r\nāœ… SUBDOMAIN UPDATED āœ…\r\n\r\n")) if err != nil { @@ -251,7 +256,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { i.CommandBuffer.Reset() } -func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) { +func (i *Interaction) HandleSlugCancel(connection ssh.Channel) { i.EditMode = false _, err := connection.Write([]byte("\033[H\033[2J")) if err != nil { @@ -278,7 +283,7 @@ func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *by } i.ShowWelcomeMessage() - commandBuffer.Reset() + i.CommandBuffer.Reset() } func (i *Interaction) HandleSlugUpdateError() { @@ -296,7 +301,7 @@ func (i *Interaction) HandleSlugUpdateError() { } } -func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) { +func (i *Interaction) HandleCommand(command string) { switch command { case "/bye": i.SendMessage("\r\nClosing connection...") @@ -307,7 +312,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) } return case "/help": - i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug") + i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug\r\n") case "/clear": i.SendMessage("\033[H\033[2J") i.ShowWelcomeMessage() @@ -323,7 +328,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) } case "/slug": if i.Forwarder.GetTunnelType() != types.HTTP { - i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType()))) + i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType())) } else { i.EditMode = true i.EditSlug = i.SlugManager.Get() @@ -335,7 +340,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) i.SendMessage("Unknown command") } - commandBuffer.Reset() + i.CommandBuffer.Reset() } func (i *Interaction) ShowWelcomeMessage() { @@ -401,6 +406,10 @@ func (i *Interaction) DisplaySlugEditor() { i.SendMessage("\r\n\r\n") } +func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) bool) { + i.updateClientSlug = modificator +} + func centerText(text string, width int) string { padding := (width - len(text)) / 2 if padding < 0 { @@ -408,6 +417,7 @@ func centerText(text string, width int) string { } return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding) } + func isValidSlug(slug string) bool { if len(slug) < 3 || len(slug) > 20 { return false @@ -436,10 +446,6 @@ func waitForKeyPress(connection ssh.Channel) { } } -var forbiddenSlug = []string{ - "ping", -} - func isForbiddenSlug(slug string) bool { for _, s := range forbiddenSlug { if slug == s { diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 2038c2a..29b02ed 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -7,6 +7,7 @@ import ( "log" "net" "time" + portUtil "tunnel_pls/internal/port" "tunnel_pls/session/slug" "tunnel_pls/types" @@ -20,6 +21,7 @@ type Interaction interface { type Forwarder interface { Close() error GetTunnelType() types.TunnelType + GetForwardedPort() uint16 } type Lifecycle struct { @@ -27,9 +29,14 @@ type Lifecycle struct { Conn ssh.Conn Channel ssh.Channel - Interaction Interaction - Forwarder Forwarder - SlugManager slug.Manager + Interaction Interaction + Forwarder Forwarder + SlugManager slug.Manager + unregisterClient func(slug string) +} + +func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) { + l.unregisterClient = unregisterClient } type SessionLifecycle interface { @@ -39,6 +46,7 @@ type SessionLifecycle interface { GetConnection() ssh.Conn GetChannel() ssh.Channel SetChannel(channel ssh.Channel) + SetUnregisterClient(unregisterClient func(slug string)) } func (l *Lifecycle) GetChannel() ssh.Channel { @@ -84,15 +92,9 @@ func (l *Lifecycle) WaitForRunningStatus() { func (l *Lifecycle) Close() error { err := l.Forwarder.Close() - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { return err } - //if s.Forwarder.Listener != nil { - // err := s.Forwarder.Listener.Close() - // if err != nil && !errors.Is(err, net.ErrClosed) { - // return err - // } - //} if l.Channel != nil { err := l.Channel.Close() @@ -108,17 +110,17 @@ func (l *Lifecycle) Close() error { } } - //clientSlug := l.SlugManager.Get() - //if clientSlug != "" { - // unregisterClient(clientSlug) - //} + clientSlug := l.SlugManager.Get() + if clientSlug != "" { + l.unregisterClient(clientSlug) + } - //if l.Forwarder.GetType() == "TCP" && s.Forwarder.Listener != nil { - // err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) - // if err != nil { - // return err - // } - //} + if l.Forwarder.GetTunnelType() == types.TCP { + err := portUtil.Manager.SetPortStatus(l.Forwarder.GetForwardedPort(), false) + if err != nil { + return err + } + } return nil } diff --git a/session/session.go b/session/session.go index b45b932..e122e38 100644 --- a/session/session.go +++ b/session/session.go @@ -1,6 +1,7 @@ package session import ( + "bytes" "log" "sync" "tunnel_pls/session/forwarder" @@ -12,11 +13,12 @@ import ( "golang.org/x/crypto/ssh" ) -type Session interface { - lifecycle.Lifecycle - interaction.InteractionController - forwarder.ForwardingController +var ( + clientsMutex sync.RWMutex + Clients = make(map[string]*SSHSession) +) +type Session interface { HandleGlobalRequest(ch <-chan *ssh.Request) HandleTCPIPForward(req *ssh.Request) HandleHTTPForward(req *ssh.Request, port uint16) @@ -25,7 +27,7 @@ type Session interface { type SSHSession struct { Lifecycle lifecycle.SessionLifecycle - Interaction interaction.InteractionController + Interaction interaction.Controller Forwarder forwarder.ForwardingController SlugManager slug.Manager } @@ -39,7 +41,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan SlugManager: slugManager, } interactionManager := &interaction.Interaction{ - CommandBuffer: nil, + CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)), EditMode: false, EditSlug: "", SlugManager: slugManager, @@ -54,13 +56,18 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan Forwarder: forwarderManager, SlugManager: slugManager, } + + interactionManager.SetLifecycle(lifecycleManager) + interactionManager.SetSlugModificator(updateClientSlug) + forwarderManager.SetLifecycle(lifecycleManager) + lifecycleManager.SetUnregisterClient(unregisterClient) + session := &SSHSession{ Lifecycle: lifecycleManager, Interaction: interactionManager, Forwarder: forwarderManager, SlugManager: slugManager, } - interactionManager.SetLifecycle(lifecycleManager) go func() { go session.Lifecycle.WaitForRunningStatus() @@ -70,7 +77,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan if session.Lifecycle.GetChannel() == nil { session.Lifecycle.SetChannel(ch) session.Interaction.SetChannel(ch) - //session.Interaction.channel = ch session.Lifecycle.SetStatus(types.SETUP) go session.HandleGlobalRequest(forwardingReq) } @@ -84,10 +90,24 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan }() } -var ( - clientsMutex sync.RWMutex - Clients = make(map[string]*SSHSession) -) +func updateClientSlug(oldSlug, newSlug string) bool { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + if _, exists := Clients[newSlug]; exists && newSlug != oldSlug { + return false + } + + client, ok := Clients[oldSlug] + if !ok { + return false + } + + delete(Clients, oldSlug) + client.SlugManager.Set(newSlug) + Clients[newSlug] = client + return true +} func registerClient(slug string, session *SSHSession) bool { clientsMutex.Lock()