diff --git a/main.go b/main.go index 8198f92..60a55f1 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "tunnel_pls/internal/config" "tunnel_pls/internal/key" "tunnel_pls/server" + "tunnel_pls/session" "tunnel_pls/version" "golang.org/x/crypto/ssh" @@ -58,6 +59,11 @@ func main() { } sshConfig.AddHostKey(private) - app := server.NewServer(sshConfig) + sessionRegistry := session.NewRegistry() + + app, err := server.NewServer(sshConfig, sessionRegistry) + if err != nil { + log.Fatalf("Failed to start server: %s", err) + } app.Start() } diff --git a/server/handler.go b/server/handler.go deleted file mode 100644 index 494e7f5..0000000 --- a/server/handler.go +++ /dev/null @@ -1,28 +0,0 @@ -package server - -import ( - "log" - "net" - "tunnel_pls/session" - - "golang.org/x/crypto/ssh" -) - -func (s *Server) handleConnection(conn net.Conn) { - sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) - if err != nil { - log.Printf("failed to establish SSH connection: %v", err) - err := conn.Close() - if err != nil { - log.Printf("failed to close SSH connection: %v", err) - return - } - return - } - - log.Println("SSH connection established:", sshConn.User()) - - session.New(sshConn, forwardingReqs, chans) - - return -} diff --git a/server/header.go b/server/header.go index ec0c224..584394b 100644 --- a/server/header.go +++ b/server/header.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "fmt" - "strings" ) type HeaderManager interface { @@ -44,43 +43,132 @@ type requestHeaderFactory struct { headers map[string]string } -func NewRequestHeaderFactory(br *bufio.Reader) (RequestHeaderManager, error) { +func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) { + switch v := r.(type) { + case []byte: + return parseHeadersFromBytes(v) + case *bufio.Reader: + return parseHeadersFromReader(v) + default: + return nil, fmt.Errorf("unsupported type: %T", r) + } +} + +func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) { header := &requestHeaderFactory{ - headers: make(map[string]string), + headers: make(map[string]string, 16), } - startLine, err := br.ReadString('\n') - if err != nil { - return nil, err + lineEnd := bytes.IndexByte(headerData, '\n') + if lineEnd == -1 { + return nil, fmt.Errorf("invalid request: no newline found") } - startLine = strings.TrimRight(startLine, "\r\n") - header.startLine = []byte(startLine) - parts := strings.Split(startLine, " ") + startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n") + header.startLine = make([]byte, len(startLine)) + copy(header.startLine, startLine) + + parts := bytes.Split(startLine, []byte{' '}) if len(parts) < 3 { return nil, fmt.Errorf("invalid request line") } - header.method = parts[0] - header.path = parts[1] - header.version = parts[2] + header.method = string(parts[0]) + header.path = string(parts[1]) + header.version = string(parts[2]) - for { - line, err := br.ReadString('\n') - if err != nil { - return nil, err + remaining := headerData[lineEnd+1:] + + for len(remaining) > 0 { + lineEnd = bytes.IndexByte(remaining, '\n') + if lineEnd == -1 { + lineEnd = len(remaining) } - line = strings.TrimRight(line, "\r\n") - if line == "" { + line := bytes.TrimRight(remaining[:lineEnd], "\r\n") + + if len(line) == 0 { break } - kv := strings.SplitN(line, ":", 2) - if len(kv) != 2 { + colonIdx := bytes.IndexByte(line, ':') + if colonIdx != -1 { + key := bytes.TrimSpace(line[:colonIdx]) + value := bytes.TrimSpace(line[colonIdx+1:]) + header.headers[string(key)] = string(value) + } + + if lineEnd == len(remaining) { + break + } + remaining = remaining[lineEnd+1:] + } + + return header, nil +} + +func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) { + header := &requestHeaderFactory{ + headers: make(map[string]string, 16), + } + + startLineBytes, err := br.ReadSlice('\n') + if err != nil { + if err == bufio.ErrBufferFull { + var startLine string + startLine, err = br.ReadString('\n') + if err != nil { + return nil, err + } + startLineBytes = []byte(startLine) + } else { + return nil, err + } + } + + startLineBytes = bytes.TrimRight(startLineBytes, "\r\n") + header.startLine = make([]byte, len(startLineBytes)) + copy(header.startLine, startLineBytes) + + parts := bytes.Split(startLineBytes, []byte{' '}) + if len(parts) < 3 { + return nil, fmt.Errorf("invalid request line") + } + + header.method = string(parts[0]) + header.path = string(parts[1]) + header.version = string(parts[2]) + + for { + lineBytes, err := br.ReadSlice('\n') + if err != nil { + if err == bufio.ErrBufferFull { + var line string + line, err = br.ReadString('\n') + if err != nil { + return nil, err + } + lineBytes = []byte(line) + } else { + return nil, err + } + } + + lineBytes = bytes.TrimRight(lineBytes, "\r\n") + + if len(lineBytes) == 0 { + break + } + + colonIdx := bytes.IndexByte(lineBytes, ':') + if colonIdx == -1 { continue } - header.headers[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) + + key := bytes.TrimSpace(lineBytes[:colonIdx]) + value := bytes.TrimSpace(lineBytes[colonIdx+1:]) + + header.headers[string(key)] = string(value) } return header, nil diff --git a/server/http.go b/server/http.go index 9c2e506..433b9a0 100644 --- a/server/http.go +++ b/server/http.go @@ -17,15 +17,9 @@ import ( "golang.org/x/crypto/ssh" ) -type Interaction interface { - SendMessage(message string) -} - type HTTPWriter interface { io.Reader io.Writer - SetInteraction(interaction Interaction) - AddInteraction(interaction Interaction) GetRemoteAddr() net.Addr GetWriter() io.Writer AddResponseMiddleware(mw ResponseMiddleware) @@ -35,21 +29,16 @@ type HTTPWriter interface { } type customWriter struct { - remoteAddr net.Addr - writer io.Writer - reader io.Reader - headerBuf []byte - buf []byte - respHeader ResponseHeaderManager - reqHeader RequestHeaderManager - interaction Interaction - respMW []ResponseMiddleware - reqStartMW []RequestMiddleware - reqEndMW []RequestMiddleware -} - -func (cw *customWriter) SetInteraction(interaction Interaction) { - cw.interaction = interaction + remoteAddr net.Addr + writer io.Writer + reader io.Reader + headerBuf []byte + buf []byte + respHeader ResponseHeaderManager + reqHeader RequestHeaderManager + respMW []ResponseMiddleware + reqStartMW []RequestMiddleware + reqEndMW []RequestMiddleware } func (cw *customWriter) GetRemoteAddr() net.Addr { @@ -110,8 +99,7 @@ func (cw *customWriter) Read(p []byte) (int, error) { } } - headerReader := bufio.NewReader(bytes.NewReader(header)) - reqhf, err := NewRequestHeaderFactory(headerReader) + reqhf, err := NewRequestHeaderFactory(header) if err != nil { return 0, err } @@ -135,11 +123,10 @@ func (cw *customWriter) Read(p []byte) (int, error) { func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter { return &customWriter{ - remoteAddr: remoteAddr, - writer: writer, - reader: reader, - buf: make([]byte, 0, 4096), - interaction: nil, + remoteAddr: remoteAddr, + writer: writer, + reader: reader, + buf: make([]byte, 0, 4096), } } @@ -224,13 +211,23 @@ func (cw *customWriter) Write(p []byte) (int, error) { return len(p), nil } -func (cw *customWriter) AddInteraction(interaction Interaction) { - cw.interaction = interaction -} - var redirectTLS = false -func NewHTTPServer() error { +type HTTPServer interface { + ListenAndServe() error + ListenAndServeTLS() error + handler(conn net.Conn) + handlerTLS(conn net.Conn) +} +type httpServer struct { + sessionRegistry session.Registry +} + +func NewHTTPServer(sessionRegistry session.Registry) HTTPServer { + return &httpServer{sessionRegistry: sessionRegistry} +} + +func (hs *httpServer) ListenAndServe() error { httpPort := config.Getenv("HTTP_PORT", "8080") listener, err := net.Listen("tcp", ":"+httpPort) if err != nil { @@ -251,13 +248,13 @@ func NewHTTPServer() error { continue } - go Handler(conn) + go hs.handler(conn) } }() return nil } -func Handler(conn net.Conn) { +func (hs *httpServer) handler(conn net.Conn) { defer func() { err := conn.Close() if err != nil && !errors.Is(err, net.ErrClosed) { @@ -316,8 +313,8 @@ func Handler(conn net.Conn) { return } - sshSession, ok := session.Clients[slug] - if !ok { + sshSession, exist := hs.sessionRegistry.Get(slug) + if !exist { _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + "Content-Length: 0\r\n" + diff --git a/server/https.go b/server/https.go index 55849cf..90ffd49 100644 --- a/server/https.go +++ b/server/https.go @@ -9,10 +9,9 @@ import ( "net" "strings" "tunnel_pls/internal/config" - "tunnel_pls/session" ) -func NewHTTPSServer() error { +func (hs *httpServer) ListenAndServeTLS() error { domain := config.Getenv("DOMAIN", "localhost") httpsPort := config.Getenv("HTTPS_PORT", "8443") @@ -38,13 +37,13 @@ func NewHTTPSServer() error { continue } - go HandlerTLS(conn) + go hs.handlerTLS(conn) } }() return nil } -func HandlerTLS(conn net.Conn) { +func (hs *httpServer) handlerTLS(conn net.Conn) { defer func() { err := conn.Close() if err != nil { @@ -90,8 +89,8 @@ func HandlerTLS(conn net.Conn) { return } - sshSession, ok := session.Clients[slug] - if !ok { + sshSession, exist := hs.sessionRegistry.Get(slug) + if !exist { _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + "Content-Length: 0\r\n" + diff --git a/server/server.go b/server/server.go index 531b3d7..2b9fda4 100644 --- a/server/server.go +++ b/server/server.go @@ -4,50 +4,45 @@ import ( "fmt" "log" "net" - "net/http" "tunnel_pls/internal/config" + "tunnel_pls/session" "golang.org/x/crypto/ssh" ) type Server struct { - conn *net.Listener - config *ssh.ServerConfig - httpServer *http.Server + conn *net.Listener + config *ssh.ServerConfig + sessionRegistry session.Registry } -func (s *Server) GetConn() *net.Listener { - return s.conn -} - -func (s *Server) GetConfig() *ssh.ServerConfig { - return s.config -} - -func (s *Server) GetHttpServer() *http.Server { - return s.httpServer -} - -func NewServer(sshConfig *ssh.ServerConfig) *Server { +func NewServer(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry) (*Server, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) if err != nil { log.Fatalf("failed to listen on port 2200: %v", err) - return nil + return nil, err } - if config.Getenv("TLS_ENABLED", "false") == "true" { - err = NewHTTPSServer() - if err != nil { - log.Fatalf("failed to start https server: %v", err) - } - } - err = NewHTTPServer() + + HttpServer := NewHTTPServer(sessionRegistry) + err = HttpServer.ListenAndServe() if err != nil { log.Fatalf("failed to start http server: %v", err) + return nil, err } + + if config.Getenv("TLS_ENABLED", "false") == "true" { + err = HttpServer.ListenAndServeTLS() + if err != nil { + log.Fatalf("failed to start https server: %v", err) + return nil, err + } + } + return &Server{ - conn: &listener, - config: sshConfig, - } + conn: &listener, + config: sshConfig, + sessionRegistry: sessionRegistry, + }, nil } func (s *Server) Start() { @@ -62,3 +57,26 @@ func (s *Server) Start() { go s.handleConnection(conn) } } + +func (s *Server) handleConnection(conn net.Conn) { + sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) + if err != nil { + log.Printf("failed to establish SSH connection: %v", err) + err := conn.Close() + if err != nil { + log.Printf("failed to close SSH connection: %v", err) + return + } + return + } + + log.Println("SSH connection established:", sshConn.User()) + + sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry) + err = sshSession.Start() + if err != nil { + log.Printf("SSH session ended with error: %v", err) + return + } + return +} diff --git a/server/tls.go b/server/tls.go index 5933026..8cc8afe 100644 --- a/server/tls.go +++ b/server/tls.go @@ -301,7 +301,22 @@ func (tm *tlsManager) initCertMagic() error { func (tm *tlsManager) getTLSConfig() *tls.Config { return &tls.Config{ GetCertificate: tm.getCertificate, - MinVersion: tls.VersionTLS12, + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + + SessionTicketsDisabled: false, + + CipherSuites: []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_CHACHA20_POLY1305_SHA256, + }, + + CurvePreferences: []tls.CurveID{ + tls.X25519, + }, + + ClientAuth: tls.NoClientCert, + NextProtos: nil, } } diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 4558533..3d32a43 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -152,25 +152,26 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA log.Printf("Handling new forwarded connection from %s", remoteAddr) - done := make(chan struct{}, 2) - - go func() { - _, err := copyWithBuffer(src, dst) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying from conn.Reader to channel: %v", err) - } - done <- struct{}{} - }() + var wg sync.WaitGroup + wg.Add(2) go func() { + defer wg.Done() _, err := copyWithBuffer(dst, src) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying from channel to conn.Writer: %v", err) + log.Printf("Error copying src→dst: %v", err) } - done <- struct{}{} }() - <-done + go func() { + defer wg.Done() + _, err := copyWithBuffer(src, dst) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + log.Printf("Error copying dst→src: %v", err) + } + }() + + wg.Wait() } func (f *Forwarder) SetType(tunnelType types.TunnelType) { diff --git a/session/handler.go b/session/handler.go index d4c808c..30458fb 100644 --- a/session/handler.go +++ b/session/handler.go @@ -106,7 +106,6 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { } portToBind := uint16(rawPortToBind) - if isBlockedPort(portToBind) { log.Printf("Port %d is blocked or restricted", portToBind) err := req.Reply(false, nil) @@ -164,16 +163,9 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { } func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { - slug := generateUniqueSlug() - if slug == "" { - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - } - return - } + slug := random.GenerateRandomString(20) - if !registerClient(slug, s) { + if !s.registry.Register(slug, s) { log.Printf("Failed to register client with slug: %s", slug) err := req.Reply(false, nil) if err != nil { @@ -186,7 +178,7 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) if err != nil { log.Println("Failed to write port to buffer:", err) - unregisterClient(slug) + s.registry.Remove(slug) err = req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -198,7 +190,7 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) - unregisterClient(slug) + s.registry.Remove(slug) err = req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -271,25 +263,6 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind s.interaction.Start() } -func generateUniqueSlug() string { - maxAttempts := 5 - - for i := 0; i < maxAttempts; i++ { - slug := random.GenerateRandomString(20) - - clientsMutex.RLock() - _, exists := Clients[slug] - clientsMutex.RUnlock() - - if !exists { - return slug - } - } - - log.Println("Failed to generate unique slug after multiple attempts") - return "" -} - func readSSHString(reader *bytes.Reader) (string, error) { var length uint32 if err := binary.Read(reader, binary.BigEndian, &length); err != nil { diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 93d6060..2b24e60 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -114,7 +114,7 @@ func (i *Interaction) SetChannel(channel ssh.Channel) { i.channel = channel } -func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) bool) { +func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) (success bool)) { i.updateClientSlug = modificator } @@ -199,6 +199,11 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } if m.editingSlug { + if m.tunnelType != types.HTTP { + m.editingSlug = false + m.slugError = "" + return m, tea.Batch(tea.ClearScreen, textinput.Blink) + } switch msg.String() { case "esc": m.editingSlug = false diff --git a/session/registry.go b/session/registry.go new file mode 100644 index 0000000..cc1e955 --- /dev/null +++ b/session/registry.go @@ -0,0 +1,66 @@ +package session + +import "sync" + +type Registry interface { + Get(slug string) (session *SSHSession, exist bool) + Update(oldSlug, newSlug string) (success bool) + Register(slug string, session *SSHSession) (success bool) + Remove(slug string) +} +type registry struct { + mu sync.RWMutex + clients map[string]*SSHSession +} + +func NewRegistry() Registry { + return ®istry{ + clients: make(map[string]*SSHSession), + } +} + +func (r *registry) Get(slug string) (session *SSHSession, exist bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + session, exist = r.clients[slug] + return +} + +func (r *registry) Update(oldSlug, newSlug string) (success bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.clients[newSlug]; exists && newSlug != oldSlug { + return false + } + + client, ok := r.clients[oldSlug] + if !ok { + return false + } + + delete(r.clients, oldSlug) + client.slugManager.Set(newSlug) + r.clients[newSlug] = client + return true +} + +func (r *registry) Register(slug string, session *SSHSession) (success bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.clients[slug]; exists { + return false + } + + r.clients[slug] = session + return true +} + +func (r *registry) Remove(slug string) { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.clients, slug) +} diff --git a/session/session.go b/session/session.go index db5fc27..9a35770 100644 --- a/session/session.go +++ b/session/session.go @@ -1,8 +1,8 @@ package session import ( + "fmt" "log" - "sync" "time" "tunnel_pls/internal/config" "tunnel_pls/session/forwarder" @@ -13,11 +13,6 @@ import ( "golang.org/x/crypto/ssh" ) -var ( - clientsMutex sync.RWMutex - Clients = make(map[string]*SSHSession) -) - type Session interface { HandleGlobalRequest(ch <-chan *ssh.Request) HandleTCPIPForward(req *ssh.Request) @@ -26,10 +21,13 @@ type Session interface { } type SSHSession struct { - lifecycle lifecycle.SessionLifecycle - interaction interaction.Controller - forwarder forwarder.ForwardingController - slugManager slug.Manager + initialReq <-chan *ssh.Request + sshReqChannel <-chan ssh.NewChannel + lifecycle lifecycle.SessionLifecycle + interaction interaction.Controller + forwarder forwarder.ForwardingController + slugManager slug.Manager + registry Registry } func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle { @@ -48,55 +46,64 @@ func (s *SSHSession) GetSlugManager() slug.Manager { return s.slugManager } -func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { +func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry) *SSHSession { slugManager := slug.NewManager() forwarderManager := forwarder.NewForwarder(slugManager) interactionManager := interaction.NewInteraction(slugManager, forwarderManager) lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager) interactionManager.SetLifecycle(lifecycleManager) - interactionManager.SetSlugModificator(updateClientSlug) + interactionManager.SetSlugModificator(sessionRegistry.Update) forwarderManager.SetLifecycle(lifecycleManager) - lifecycleManager.SetUnregisterClient(unregisterClient) + lifecycleManager.SetUnregisterClient(sessionRegistry.Remove) - session := &SSHSession{ - lifecycle: lifecycleManager, - interaction: interactionManager, - forwarder: forwarderManager, - slugManager: slugManager, - } - - var once sync.Once - for channel := range sshChan { - ch, reqs, err := channel.Accept() - if err != nil { - log.Printf("failed to accept channel: %v", err) - continue - } - once.Do(func() { - session.lifecycle.SetChannel(ch) - session.interaction.SetChannel(ch) - - tcpipReq := session.waitForTCPIPForward(forwardingReq) - if tcpipReq == nil { - log.Printf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")) - if err := session.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } - return - } - go session.HandleTCPIPForward(tcpipReq) - }) - session.HandleGlobalRequest(reqs) - } - if err := session.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) + return &SSHSession{ + initialReq: forwardingReq, + sshReqChannel: sshChan, + lifecycle: lifecycleManager, + interaction: interactionManager, + forwarder: forwarderManager, + slugManager: slugManager, + registry: sessionRegistry, } } -func (s *SSHSession) waitForTCPIPForward(forwardingReq <-chan *ssh.Request) *ssh.Request { +func (s *SSHSession) Start() error { + channel := <-s.sshReqChannel + ch, reqs, err := channel.Accept() + if err != nil { + log.Printf("failed to accept channel: %v", err) + return err + } + go s.HandleGlobalRequest(reqs) + + tcpipReq := s.waitForTCPIPForward() + if tcpipReq == nil { + _, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))) + if err != nil { + return err + } + if err := s.lifecycle.Close(); err != nil { + log.Printf("failed to close session: %v", err) + } + return fmt.Errorf("No forwarding Request") + } + + s.lifecycle.SetChannel(ch) + s.interaction.SetChannel(ch) + + s.HandleTCPIPForward(tcpipReq) + + if err := s.lifecycle.Close(); err != nil { + log.Printf("failed to close session: %v", err) + return err + } + return nil +} + +func (s *SSHSession) waitForTCPIPForward() *ssh.Request { select { - case req, ok := <-forwardingReq: + case req, ok := <-s.initialReq: if !ok { log.Println("Forwarding request channel closed") return nil @@ -114,41 +121,3 @@ func (s *SSHSession) waitForTCPIPForward(forwardingReq <-chan *ssh.Request) *ssh return nil } } - -func updateClientSlug(oldSlug, newSlug string) bool { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - if _, exists := Clients[newSlug]; exists && newSlug != oldSlug { - return false - } - - client, ok := Clients[oldSlug] - if !ok { - return false - } - - delete(Clients, oldSlug) - client.slugManager.Set(newSlug) - Clients[newSlug] = client - return true -} - -func registerClient(slug string, session *SSHSession) bool { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - if _, exists := Clients[slug]; exists { - return false - } - - Clients[slug] = session - return true -} - -func unregisterClient(slug string) { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - delete(Clients, slug) -} diff --git a/session/slug/slug.go b/session/slug/slug.go index 4900e22..7ab4697 100644 --- a/session/slug/slug.go +++ b/session/slug/slug.go @@ -1,32 +1,24 @@ package slug -import "sync" - type Manager interface { Get() string Set(slug string) } type manager struct { - slug string - slugMu sync.RWMutex + slug string } func NewManager() Manager { return &manager{ - slug: "", - slugMu: sync.RWMutex{}, + slug: "", } } func (s *manager) Get() string { - s.slugMu.RLock() - defer s.slugMu.RUnlock() return s.slug } func (s *manager) Set(slug string) { - s.slugMu.Lock() s.slug = slug - s.slugMu.Unlock() }