From 039e97914260737493bce1cb39854b1f1bc99be1 Mon Sep 17 00:00:00 2001 From: bagas Date: Thu, 4 Dec 2025 19:32:00 +0700 Subject: [PATCH 1/7] refactor: restructure session initialization to avoid circular references --- server/http.go | 14 +- server/middleware.go | 8 +- session/forwarder.go | 37 ----- session/forwarder/forwarder.go | 84 +++++++++++ session/handler.go | 181 +++-------------------- session/{ => interaction}/interaction.go | 127 ++++++++++------ session/lifecycle/lifecycle.go | 124 ++++++++++++++++ session/session.go | 155 +++++++++---------- session/slug/slug.go | 32 ++++ types/types.go | 16 ++ 10 files changed, 448 insertions(+), 330 deletions(-) delete mode 100644 session/forwarder.go create mode 100644 session/forwarder/forwarder.go rename session/{ => interaction}/interaction.go (83%) create mode 100644 session/lifecycle/lifecycle.go create mode 100644 session/slug/slug.go create mode 100644 types/types.go diff --git a/server/http.go b/server/http.go index 3f1aaba..d2ad943 100644 --- a/server/http.go +++ b/server/http.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" "tunnel_pls/session" + "tunnel_pls/session/interaction" "tunnel_pls/utils" "golang.org/x/crypto/ssh" @@ -29,12 +30,16 @@ type CustomWriter struct { buf []byte respHeader *ResponseHeaderFactory reqHeader *RequestHeaderFactory - interaction *session.Interaction + interaction interaction.InteractionController respMW []ResponseMiddleware reqStartMW []RequestMiddleware reqEndMW []RequestMiddleware } +func (cw *CustomWriter) SetInteraction(interaction interaction.InteractionController) { + cw.interaction = interaction +} + func (cw *CustomWriter) Read(p []byte) (int, error) { tmp := make([]byte, len(p)) read, err := cw.reader.Read(tmp) @@ -177,7 +182,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return n, nil } -func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) { +func (cw *CustomWriter) AddInteraction(interaction *interaction.Interaction) { cw.interaction = interaction } @@ -287,16 +292,15 @@ func Handler(conn net.Conn) { return } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - + cw.SetInteraction(sshSession.Interaction) forwardRequest(cw, reqhf, sshSession) return } func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { - cw.AddInteraction(sshSession.Interaction) originHost, originPort := ParseAddr(cw.RemoteAddr.String()) payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort()) - channel, reqs, err := sshSession.Conn.OpenChannel("forwarded-tcpip", payload) + channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) sendBadGatewayResponse(cw) diff --git a/server/middleware.go b/server/middleware.go index 08ee035..d5f733b 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -4,7 +4,7 @@ import ( "fmt" "net" "time" - "tunnel_pls/session" + "tunnel_pls/session/interaction" ) type RequestMiddleware interface { @@ -29,13 +29,13 @@ func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body [ } type RequestLogger struct { - interaction session.Interaction + interaction interaction.InteractionController remoteAddr net.Addr } -func NewRequestLogger(interaction *session.Interaction, remoteAddr net.Addr) *RequestLogger { +func NewRequestLogger(interaction interaction.InteractionController, remoteAddr net.Addr) *RequestLogger { return &RequestLogger{ - interaction: *interaction, + interaction: interaction, remoteAddr: remoteAddr, } } diff --git a/session/forwarder.go b/session/forwarder.go deleted file mode 100644 index e7abc17..0000000 --- a/session/forwarder.go +++ /dev/null @@ -1,37 +0,0 @@ -package session - -import ( - "net" - - "golang.org/x/crypto/ssh" -) - -type Forwarder struct { - Listener net.Listener - TunnelType TunnelType - ForwardedPort uint16 - - getSlug func() string - setSlug func(string) -} - -type ForwardingController interface { - HandleGlobalRequest(ch <-chan *ssh.Request) - HandleTCPIPForward(req *ssh.Request) - HandleHTTPForward(req *ssh.Request, port uint16) - HandleTCPForward(req *ssh.Request, addr string, port uint16) - AcceptTCPConnections() -} - -type ForwarderInfo interface { - GetTunnelType() TunnelType - GetForwardedPort() uint16 -} - -func (f *Forwarder) GetTunnelType() TunnelType { - return f.TunnelType -} - -func (f *Forwarder) GetForwardedPort() uint16 { - return f.ForwardedPort -} diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go new file mode 100644 index 0000000..82794ba --- /dev/null +++ b/session/forwarder/forwarder.go @@ -0,0 +1,84 @@ +package forwarder + +import ( + "net" + "tunnel_pls/session/slug" + "tunnel_pls/types" +) + +type Forwarder struct { + Listener net.Listener + TunnelType types.TunnelType + ForwardedPort uint16 + SlugManager slug.Manager +} + +func (f *Forwarder) AcceptTCPConnections() { + panic("implement me") +} + +func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { + panic("implement me") +} + +func (f *Forwarder) SetType(tunnelType types.TunnelType) { + f.TunnelType = tunnelType +} + +func (f *Forwarder) GetTunnelType() types.TunnelType { + return f.TunnelType +} + +func (f *Forwarder) GetForwardedPort() uint16 { + return f.ForwardedPort +} + +func (f *Forwarder) SetForwardedPort(port uint16) { + f.ForwardedPort = port +} + +func (f *Forwarder) SetListener(listener net.Listener) { + f.Listener = listener +} + +func (f *Forwarder) GetListener() net.Listener { + return f.Listener +} + +func (f *Forwarder) Close() error { + if f.GetTunnelType() != types.HTTP { + return f.Listener.Close() + } + return nil +} + +type ForwardingController interface { + AcceptTCPConnections() + UpdateClientSlug(oldSlug, newSlug string) bool + SetType(tunnelType types.TunnelType) + GetTunnelType() types.TunnelType + GetForwardedPort() uint16 + SetForwardedPort(port uint16) + SetListener(listener net.Listener) + GetListener() net.Listener + Close() error +} + +//func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { +// session.clientsMutex.Lock() +// defer session.clientsMutex.Unlock() +// +// if _, exists := session.Clients[newSlug]; exists && newSlug != oldSlug { +// return false +// } +// +// client, ok := session.Clients[oldSlug] +// if !ok { +// return false +// } +// +// delete(session.Clients, oldSlug) +// f.SlugManager.Set(newSlug) +// session.Clients[newSlug] = client +// return true +//} diff --git a/session/handler.go b/session/handler.go index 5c63338..c807d65 100644 --- a/session/handler.go +++ b/session/handler.go @@ -9,92 +9,24 @@ import ( "log" "net" "strconv" - "sync" - "time" portUtil "tunnel_pls/internal/port" + "tunnel_pls/types" "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) -type Status string - -var forbiddenSlug = []string{ - "ping", -} - type UserConnection struct { Reader io.Reader Writer net.Conn } -var ( - clientsMutex sync.RWMutex - Clients = make(map[string]*SSHSession) -) - -func registerClient(slug string, session *SSHSession) bool { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - if _, exists := Clients[slug]; exists { - return false - } - - Clients[slug] = session - return true -} - -func unregisterClient(slug string) { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - delete(Clients, slug) -} - -func (s *SSHSession) Close() error { - if s.Forwarder.Listener != nil { - err := s.Forwarder.Listener.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - return err - } - } - - if s.channel != nil { - err := s.channel.Close() - if err != nil && !errors.Is(err, io.EOF) { - return err - } - } - - if s.Conn != nil { - err := s.Conn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - return err - } - } - - slug := s.Forwarder.getSlug() - if slug != "" { - unregisterClient(slug) - } - - if s.Forwarder.TunnelType == TCP && s.Forwarder.Listener != nil { - err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) - if err != nil { - return err - } - } - - return nil -} - func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { for req := range GlobalRequest { switch req.Type { case "tcpip-forward": - s.handleTCPIPForward(req) + s.HandleTCPIPForward(req) return case "shell", "pty-req", "window-change": err := req.Reply(true, nil) @@ -113,7 +45,7 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { } } -func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { +func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { log.Println("Port forwarding request detected") reader := bytes.NewReader(req.Payload) @@ -126,7 +58,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -142,7 +74,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -156,7 +88,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -172,7 +104,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -180,11 +112,11 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { } s.Interaction.SendMessage("\033[H\033[2J") - s.Lifecycle.Status = RUNNING + s.Lifecycle.SetStatus(types.RUNNING) go s.Interaction.HandleUserInput() if portToBind == 80 || portToBind == 443 { - s.handleHTTPForward(req, portToBind) + s.HandleHTTPForward(req, portToBind) return } else { if portToBind == 0 { @@ -197,7 +129,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -210,7 +142,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -222,7 +154,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { return } } - s.handleTCPForward(req, addr, portToBind) + s.HandleTCPForward(req, addr, portToBind) } var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} @@ -242,9 +174,9 @@ func isBlockedPort(port uint16) bool { return false } -func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { - s.Forwarder.TunnelType = HTTP - s.Forwarder.ForwardedPort = portToBind +func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { + s.Forwarder.SetType(types.HTTP) + s.Forwarder.SetForwardedPort(portToBind) slug := generateUniqueSlug() if slug == "" { @@ -256,7 +188,7 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { return } - s.Forwarder.setSlug(slug) + s.SlugManager.Set(slug) registerClient(slug, s) buf := new(bytes.Buffer) @@ -282,8 +214,8 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { } } -func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - s.Forwarder.TunnelType = TCP +func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { + s.Forwarder.SetType(types.TCP) log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) @@ -294,16 +226,16 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } return } - s.Forwarder.Listener = listener - s.Forwarder.ForwardedPort = portToBind + s.Forwarder.SetListener(listener) + s.Forwarder.SetForwardedPort(portToBind) s.Interaction.ShowWelcomeMessage() - s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.TunnelType, utils.Getenv("domain"), s.Forwarder.ForwardedPort)) + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) go s.acceptTCPConnections() @@ -323,7 +255,7 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind func (s *SSHSession) acceptTCPConnections() { for { - conn, err := s.Forwarder.Listener.Accept() + conn, err := s.Forwarder.GetListener().Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -333,7 +265,7 @@ func (s *SSHSession) acceptTCPConnections() { } originHost, originPort := ParseAddr(conn.RemoteAddr().String()) payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort()) - channel, reqs, err := s.Conn.OpenChannel("forwarded-tcpip", payload) + channel, reqs, err := s.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) return @@ -371,71 +303,6 @@ func generateUniqueSlug() string { return "" } -func (s *SSHSession) waitForRunningStatus() { - timeout := time.After(3 * time.Second) - ticker := time.NewTicker(150 * time.Millisecond) - defer ticker.Stop() - frames := []string{"-", "\\", "|", "/"} - i := 0 - for { - select { - case <-ticker.C: - s.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) - i = (i + 1) % len(frames) - if s.Lifecycle.Status == RUNNING { - s.Interaction.SendMessage("\r\033[K") - return - } - case <-timeout: - s.Interaction.SendMessage("\r\033[K") - s.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") - err := s.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } - log.Println("Timeout waiting for session to start running") - return - } - } -} - -func isForbiddenSlug(slug string) bool { - for _, s := range forbiddenSlug { - if slug == s { - return true - } - } - return false -} - -func isValidSlug(slug string) bool { - if len(slug) < 3 || len(slug) > 20 { - return false - } - - if slug[0] == '-' || slug[len(slug)-1] == '-' { - return false - } - - for _, c := range slug { - if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') { - return false - } - } - - return true -} - -func waitForKeyPress(connection ssh.Channel) { - keyBuf := make([]byte, 1) - for { - _, err := connection.Read(keyBuf) - if err == nil { - break - } - } -} - func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { defer func(src ssh.Channel) { err := src.Close() diff --git a/session/interaction.go b/session/interaction/interaction.go similarity index 83% rename from session/interaction.go rename to session/interaction/interaction.go index cfa1ce1..d9e65d6 100644 --- a/session/interaction.go +++ b/session/interaction/interaction.go @@ -1,4 +1,4 @@ -package session +package interaction import ( "bytes" @@ -7,21 +7,35 @@ import ( "log" "strings" "time" + "tunnel_pls/session/slug" + "tunnel_pls/types" "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) +type Lifecycle interface { + Close() error +} + type InteractionController interface { SendMessage(message string) HandleUserInput() - HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer) - HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer) + HandleCommand(command string, commandBuffer *bytes.Buffer) + HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) + HandleSlugSave(conn ssh.Channel) + HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) HandleSlugUpdateError() ShowWelcomeMessage() DisplaySlugEditor() + SetChannel(channel ssh.Channel) + SetLifecycle(lifecycle Lifecycle) +} + +type Forwarder interface { + Close() error + GetTunnelType() types.TunnelType + GetForwardedPort() uint16 } type Interaction struct { @@ -29,13 +43,17 @@ type Interaction struct { EditMode bool EditSlug string channel ssh.Channel + SlugManager slug.Manager + Forwarder Forwarder + Lifecycle Lifecycle +} - getSlug func() string - setSlug func(string) +func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { + i.Lifecycle = lifecycle +} - session SessionCloser - - forwarder ForwarderInfo +func (i *Interaction) SetChannel(channel ssh.Channel) { + i.channel = channel } func (i *Interaction) SendMessage(message string) { @@ -142,13 +160,13 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { return } if isValid { - oldSlug := i.getSlug() + //oldSlug := i.SlugManager.Get() newSlug := i.EditSlug - if !updateClientSlug(oldSlug, newSlug) { - i.HandleSlugUpdateError() - return - } + //if !i.updateClientSlug(oldSlug, newSlug) { + // i.HandleSlugUpdateError() + // return + //} _, err := connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")) if err != nil { @@ -223,7 +241,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { if utils.Getenv("tls_enabled") == "true" { protocol = "https" } - _, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.getSlug(), domain))) + _, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))) if err != nil { log.Printf("failed to write to channel: %v", err) return @@ -271,7 +289,7 @@ func (i *Interaction) HandleSlugUpdateError() { i.SendMessage(fmt.Sprintf("Disconnecting in %d...\r\n", iter)) time.Sleep(1 * time.Second) } - err := i.session.Close() + err := i.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) return @@ -282,7 +300,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) switch command { case "/bye": i.SendMessage("\r\nClosing connection...") - err := i.session.Close() + err := i.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) return @@ -294,21 +312,21 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) i.SendMessage("\033[H\033[2J") i.ShowWelcomeMessage() domain := utils.Getenv("domain") - if i.forwarder.GetTunnelType() == HTTP { + if i.Forwarder.GetTunnelType() == types.HTTP { protocol := "http" if utils.Getenv("tls_enabled") == "true" { protocol = "https" } - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.getSlug(), domain)) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) } else { - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.forwarder.GetTunnelType(), domain, i.forwarder.GetForwardedPort())) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.Forwarder.GetTunnelType(), domain, i.Forwarder.GetForwardedPort())) } case "/slug": - if i.forwarder.GetTunnelType() != HTTP { - i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.forwarder.GetTunnelType()))) + if i.Forwarder.GetTunnelType() != types.HTTP { + i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType()))) } else { i.EditMode = true - i.EditSlug = i.getSlug() + i.EditSlug = i.SlugManager.Get() i.SendMessage("\033[H\033[2J") i.DisplaySlugEditor() i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain")) @@ -347,7 +365,7 @@ func (i *Interaction) ShowWelcomeMessage() { func (i *Interaction) DisplaySlugEditor() { domain := utils.Getenv("domain") - fullDomain := i.getSlug() + "." + domain + fullDomain := i.SlugManager.Get() + "." + domain const paddingRight = 4 @@ -383,25 +401,6 @@ func (i *Interaction) DisplaySlugEditor() { i.SendMessage("\r\n\r\n") } -func updateClientSlug(oldSlug, newSlug string) bool { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - if _, exists := Clients[newSlug]; exists && newSlug != oldSlug { - return false - } - - client, ok := Clients[oldSlug] - if !ok { - return false - } - - delete(Clients, oldSlug) - client.Forwarder.setSlug(newSlug) - Clients[newSlug] = client - return true -} - func centerText(text string, width int) string { padding := (width - len(text)) / 2 if padding < 0 { @@ -409,3 +408,43 @@ func centerText(text string, width int) string { } return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding) } +func isValidSlug(slug string) bool { + if len(slug) < 3 || len(slug) > 20 { + return false + } + + if slug[0] == '-' || slug[len(slug)-1] == '-' { + return false + } + + for _, c := range slug { + if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') { + return false + } + } + + return true +} + +func waitForKeyPress(connection ssh.Channel) { + keyBuf := make([]byte, 1) + for { + _, err := connection.Read(keyBuf) + if err == nil { + break + } + } +} + +var forbiddenSlug = []string{ + "ping", +} + +func isForbiddenSlug(slug string) bool { + for _, s := range forbiddenSlug { + if slug == s { + return true + } + } + return false +} diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go new file mode 100644 index 0000000..2038c2a --- /dev/null +++ b/session/lifecycle/lifecycle.go @@ -0,0 +1,124 @@ +package lifecycle + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "time" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" +) + +type Interaction interface { + SendMessage(string) +} + +type Forwarder interface { + Close() error + GetTunnelType() types.TunnelType +} + +type Lifecycle struct { + Status types.Status + Conn ssh.Conn + Channel ssh.Channel + + Interaction Interaction + Forwarder Forwarder + SlugManager slug.Manager +} + +type SessionLifecycle interface { + Close() error + WaitForRunningStatus() + SetStatus(status types.Status) + GetConnection() ssh.Conn + GetChannel() ssh.Channel + SetChannel(channel ssh.Channel) +} + +func (l *Lifecycle) GetChannel() ssh.Channel { + return l.Channel +} + +func (l *Lifecycle) SetChannel(channel ssh.Channel) { + l.Channel = channel +} +func (l *Lifecycle) GetConnection() ssh.Conn { + return l.Conn +} +func (l *Lifecycle) SetStatus(status types.Status) { + l.Status = status +} +func (l *Lifecycle) WaitForRunningStatus() { + timeout := time.After(3 * time.Second) + ticker := time.NewTicker(150 * time.Millisecond) + defer ticker.Stop() + frames := []string{"-", "\\", "|", "/"} + i := 0 + for { + select { + case <-ticker.C: + l.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) + i = (i + 1) % len(frames) + if l.Status == types.RUNNING { + l.Interaction.SendMessage("\r\033[K") + return + } + case <-timeout: + l.Interaction.SendMessage("\r\033[K") + l.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") + err := l.Close() + if err != nil { + log.Printf("failed to close session: %v", err) + } + log.Println("Timeout waiting for session to start running") + return + } + } +} + +func (l *Lifecycle) Close() error { + err := l.Forwarder.Close() + if err != nil { + return err + } + //if s.Forwarder.Listener != nil { + // err := s.Forwarder.Listener.Close() + // if err != nil && !errors.Is(err, net.ErrClosed) { + // return err + // } + //} + + if l.Channel != nil { + err := l.Channel.Close() + if err != nil && !errors.Is(err, io.EOF) { + return err + } + } + + if l.Conn != nil { + err := l.Conn.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + return err + } + } + + //clientSlug := l.SlugManager.Get() + //if clientSlug != "" { + // unregisterClient(clientSlug) + //} + + //if l.Forwarder.GetType() == "TCP" && s.Forwarder.Listener != nil { + // err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) + // if err != nil { + // return err + // } + //} + + return nil +} diff --git a/session/session.go b/session/session.go index 2a38c6a..b45b932 100644 --- a/session/session.go +++ b/session/session.go @@ -1,105 +1,82 @@ package session import ( - "bytes" "log" "sync" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" + "tunnel_pls/types" "golang.org/x/crypto/ssh" ) -const ( - INITIALIZING Status = "INITIALIZING" - RUNNING Status = "RUNNING" - SETUP Status = "SETUP" -) - -type TunnelType string - -const ( - HTTP TunnelType = "http" - TCP TunnelType = "tcp" -) - -type SessionLifecycle interface { - Close() error - WaitForRunningStatus() -} - -type SessionCloser interface { - Close() error -} - type Session interface { - SessionLifecycle - InteractionController - ForwardingController -} + lifecycle.Lifecycle + interaction.InteractionController + forwarder.ForwardingController -type Lifecycle struct { - Status Status + HandleGlobalRequest(ch <-chan *ssh.Request) + HandleTCPIPForward(req *ssh.Request) + HandleHTTPForward(req *ssh.Request, port uint16) + HandleTCPForward(req *ssh.Request, addr string, port uint16) } type SSHSession struct { - Lifecycle *Lifecycle - Interaction *Interaction - Forwarder *Forwarder - - Conn *ssh.ServerConn - channel ssh.Channel - - slug string - slugMu sync.RWMutex + Lifecycle lifecycle.SessionLifecycle + Interaction interaction.InteractionController + Forwarder forwarder.ForwardingController + SlugManager slug.Manager } func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { - session := SSHSession{ - Lifecycle: &Lifecycle{ - Status: INITIALIZING, - }, - Interaction: &Interaction{ - CommandBuffer: new(bytes.Buffer), - EditMode: false, - EditSlug: "", - channel: nil, - getSlug: nil, - setSlug: nil, - session: nil, - forwarder: nil, - }, - Forwarder: &Forwarder{ - Listener: nil, - TunnelType: "", - ForwardedPort: 0, - getSlug: nil, - setSlug: nil, - }, - Conn: conn, - channel: nil, - slug: "", + slugManager := slug.NewManager() + forwarderManager := &forwarder.Forwarder{ + Listener: nil, + TunnelType: "", + ForwardedPort: 0, + SlugManager: slugManager, } - - session.Forwarder.getSlug = session.GetSlug - session.Forwarder.setSlug = session.SetSlug - session.Interaction.getSlug = session.GetSlug - session.Interaction.setSlug = session.SetSlug - session.Interaction.session = &session - session.Interaction.forwarder = session.Forwarder + interactionManager := &interaction.Interaction{ + CommandBuffer: nil, + EditMode: false, + EditSlug: "", + SlugManager: slugManager, + Forwarder: forwarderManager, + Lifecycle: nil, + } + lifecycleManager := &lifecycle.Lifecycle{ + Status: "", + Conn: conn, + Channel: nil, + Interaction: interactionManager, + Forwarder: forwarderManager, + SlugManager: slugManager, + } + session := &SSHSession{ + Lifecycle: lifecycleManager, + Interaction: interactionManager, + Forwarder: forwarderManager, + SlugManager: slugManager, + } + interactionManager.SetLifecycle(lifecycleManager) go func() { - go session.waitForRunningStatus() + go session.Lifecycle.WaitForRunningStatus() for channel := range sshChan { ch, reqs, _ := channel.Accept() - if session.channel == nil { - session.channel = ch - session.Interaction.channel = ch - session.Lifecycle.Status = SETUP + if session.Lifecycle.GetChannel() == nil { + session.Lifecycle.SetChannel(ch) + session.Interaction.SetChannel(ch) + //session.Interaction.channel = ch + session.Lifecycle.SetStatus(types.SETUP) go session.HandleGlobalRequest(forwardingReq) } go session.HandleGlobalRequest(reqs) } - err := session.Close() + err := session.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -107,14 +84,26 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan }() } -func (s *SSHSession) GetSlug() string { - s.slugMu.RLock() - defer s.slugMu.RUnlock() - return s.slug +var ( + clientsMutex sync.RWMutex + Clients = make(map[string]*SSHSession) +) + +func registerClient(slug string, session *SSHSession) bool { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + if _, exists := Clients[slug]; exists { + return false + } + + Clients[slug] = session + return true } -func (s *SSHSession) SetSlug(slug string) { - s.slugMu.Lock() - s.slug = slug - s.slugMu.Unlock() +func unregisterClient(slug string) { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + delete(Clients, slug) } diff --git a/session/slug/slug.go b/session/slug/slug.go new file mode 100644 index 0000000..4900e22 --- /dev/null +++ b/session/slug/slug.go @@ -0,0 +1,32 @@ +package slug + +import "sync" + +type Manager interface { + Get() string + Set(slug string) +} + +type manager struct { + slug string + slugMu sync.RWMutex +} + +func NewManager() Manager { + return &manager{ + slug: "", + slugMu: sync.RWMutex{}, + } +} + +func (s *manager) Get() string { + s.slugMu.RLock() + defer s.slugMu.RUnlock() + return s.slug +} + +func (s *manager) Set(slug string) { + s.slugMu.Lock() + s.slug = slug + s.slugMu.Unlock() +} diff --git a/types/types.go b/types/types.go new file mode 100644 index 0000000..c007661 --- /dev/null +++ b/types/types.go @@ -0,0 +1,16 @@ +package types + +type Status string + +const ( + INITIALIZING Status = "INITIALIZING" + RUNNING Status = "RUNNING" + SETUP Status = "SETUP" +) + +type TunnelType string + +const ( + HTTP TunnelType = "HTTP" + TCP TunnelType = "TCP" +) From 7a31047bb98b63b7caaa1ab08ec839016f38dd4f Mon Sep 17 00:00:00 2001 From: bagas Date: Thu, 4 Dec 2025 22:48:15 +0700 Subject: [PATCH 2/7] refactor: restructure session initialization to avoid circular references --- server/http.go | 6 +- server/middleware.go | 4 +- session/forwarder/forwarder.go | 145 +++++++++++++++++++++++------ session/handler.go | 127 +++---------------------- session/interaction/interaction.go | 80 ++++++++-------- session/lifecycle/lifecycle.go | 42 +++++---- session/session.go | 44 ++++++--- 7 files changed, 229 insertions(+), 219 deletions(-) diff --git a/server/http.go b/server/http.go index d2ad943..a69b836 100644 --- a/server/http.go +++ b/server/http.go @@ -30,13 +30,13 @@ type CustomWriter struct { buf []byte respHeader *ResponseHeaderFactory reqHeader *RequestHeaderFactory - interaction interaction.InteractionController + interaction interaction.Controller respMW []ResponseMiddleware reqStartMW []RequestMiddleware reqEndMW []RequestMiddleware } -func (cw *CustomWriter) SetInteraction(interaction interaction.InteractionController) { +func (cw *CustomWriter) SetInteraction(interaction interaction.Controller) { cw.interaction = interaction } @@ -350,7 +350,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS } } - sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) + sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) return } diff --git a/server/middleware.go b/server/middleware.go index d5f733b..ad8c546 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -29,11 +29,11 @@ func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body [ } type RequestLogger struct { - interaction interaction.InteractionController + interaction interaction.Controller remoteAddr net.Addr } -func NewRequestLogger(interaction interaction.InteractionController, remoteAddr net.Addr) *RequestLogger { +func NewRequestLogger(interaction interaction.Controller, remoteAddr net.Addr) *RequestLogger { return &RequestLogger{ interaction: interaction, remoteAddr: remoteAddr, diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 82794ba..41c9602 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -1,9 +1,17 @@ package forwarder import ( + "bytes" + "encoding/binary" + "errors" + "io" + "log" "net" + "strconv" "tunnel_pls/session/slug" "tunnel_pls/types" + + "golang.org/x/crypto/ssh" ) type Forwarder struct { @@ -11,14 +19,83 @@ type Forwarder struct { TunnelType types.TunnelType ForwardedPort uint16 SlugManager slug.Manager + Lifecycle Lifecycle +} + +type Lifecycle interface { + GetConnection() ssh.Conn +} + +type ForwardingController interface { + AcceptTCPConnections() + SetType(tunnelType types.TunnelType) + GetTunnelType() types.TunnelType + GetForwardedPort() uint16 + SetForwardedPort(port uint16) + SetListener(listener net.Listener) + GetListener() net.Listener + Close() error + HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) + SetLifecycle(lifecycle Lifecycle) +} + +func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { + f.Lifecycle = lifecycle } func (f *Forwarder) AcceptTCPConnections() { - panic("implement me") + for { + conn, err := f.GetListener().Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + log.Printf("Error accepting connection: %v", err) + continue + } + originHost, originPort := ParseAddr(conn.RemoteAddr().String()) + payload := createForwardedTCPIPPayload(originHost, uint16(originPort), f.GetForwardedPort()) + channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + return + } + + go func() { + for req := range reqs { + err := req.Reply(false, nil) + if err != nil { + log.Printf("Failed to reply to request: %v", err) + return + } + } + }() + go f.HandleConnection(conn, channel, conn.RemoteAddr()) + } } -func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { - panic("implement me") +func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { + defer func(src ssh.Channel) { + err := src.Close() + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing connection: %v", err) + } + }(src) + log.Printf("Handling new forwarded connection from %s", remoteAddr) + + go func() { + _, err := io.Copy(src, dst) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + log.Printf("Error copying from conn.Reader to channel: %v", err) + } + }() + + _, err := io.Copy(dst, src) + + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error copying from channel to conn.Writer: %v", err) + } + return } func (f *Forwarder) SetType(tunnelType types.TunnelType) { @@ -52,33 +129,39 @@ func (f *Forwarder) Close() error { return nil } -type ForwardingController interface { - AcceptTCPConnections() - UpdateClientSlug(oldSlug, newSlug string) bool - SetType(tunnelType types.TunnelType) - GetTunnelType() types.TunnelType - GetForwardedPort() uint16 - SetForwardedPort(port uint16) - SetListener(listener net.Listener) - GetListener() net.Listener - Close() error +func ParseAddr(addr string) (string, uint32) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint32(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint32(port) +} +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return + } + buffer.WriteString(str) } -//func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { -// session.clientsMutex.Lock() -// defer session.clientsMutex.Unlock() -// -// if _, exists := session.Clients[newSlug]; exists && newSlug != oldSlug { -// return false -// } -// -// client, ok := session.Clients[oldSlug] -// if !ok { -// return false -// } -// -// delete(session.Clients, oldSlug) -// f.SlugManager.Set(newSlug) -// session.Clients[newSlug] = client -// return true -//} +func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { + var buf bytes.Buffer + + writeSSHString(&buf, "localhost") + err := binary.Write(&buf, binary.BigEndian, uint32(port)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + writeSSHString(&buf, host) + err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + + return buf.Bytes() +} diff --git a/session/handler.go b/session/handler.go index c807d65..9123310 100644 --- a/session/handler.go +++ b/session/handler.go @@ -3,12 +3,9 @@ package session import ( "bytes" "encoding/binary" - "errors" "fmt" - "io" "log" "net" - "strconv" portUtil "tunnel_pls/internal/port" "tunnel_pls/types" @@ -17,10 +14,7 @@ import ( "golang.org/x/crypto/ssh" ) -type UserConnection struct { - Reader io.Reader - Writer net.Conn -} +var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { for req := range GlobalRequest { @@ -157,23 +151,6 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { s.HandleTCPForward(req, addr, portToBind) } -var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} - -func isBlockedPort(port uint16) bool { - if port == 80 || port == 443 { - return false - } - if port < 1024 && port != 0 { - return true - } - for _, p := range blockedReservedPorts { - if p == port { - return true - } - } - return false -} - func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { s.Forwarder.SetType(types.HTTP) s.Forwarder.SetForwardedPort(portToBind) @@ -237,7 +214,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind s.Interaction.ShowWelcomeMessage() s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) - go s.acceptTCPConnections() + go s.Forwarder.AcceptTCPConnections() buf := new(bytes.Buffer) err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) @@ -253,37 +230,6 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind } } -func (s *SSHSession) acceptTCPConnections() { - for { - conn, err := s.Forwarder.GetListener().Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) - continue - } - originHost, originPort := ParseAddr(conn.RemoteAddr().String()) - payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort()) - channel, reqs, err := s.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - return - } - - go func() { - for req := range reqs { - err := req.Reply(false, nil) - if err != nil { - log.Printf("Failed to reply to request: %v", err) - return - } - } - }() - go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr()) - } -} - func generateUniqueSlug() string { maxAttempts := 5 @@ -303,30 +249,6 @@ func generateUniqueSlug() string { return "" } -func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { - defer func(src ssh.Channel) { - err := src.Close() - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing connection: %v", err) - } - }(src) - log.Printf("Handling new forwarded connection from %s", remoteAddr) - - go func() { - _, err := io.Copy(src, dst) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying from conn.Reader to channel: %v", err) - } - }() - - _, err := io.Copy(dst, src) - - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error copying from channel to conn.Writer: %v", err) - } - return -} - func readSSHString(reader *bytes.Reader) (string, error) { var length uint32 if err := binary.Read(reader, binary.BigEndian, &length); err != nil { @@ -339,40 +261,17 @@ func readSSHString(reader *bytes.Reader) (string, error) { return string(strBytes), nil } -func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { - var buf bytes.Buffer - - writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(port)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil +func isBlockedPort(port uint16) bool { + if port == 80 || port == 443 { + return false } - writeSSHString(&buf, host) - err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil + if port < 1024 && port != 0 { + return true } - - return buf.Bytes() -} - -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return - } - buffer.WriteString(str) -} - -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) + for _, p := range blockedReservedPorts { + if p == port { + return true + } + } + return false } diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index d9e65d6..0f6c3ca 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -14,22 +14,27 @@ import ( "golang.org/x/crypto/ssh" ) +var forbiddenSlug = []string{ + "ping", +} + type Lifecycle interface { Close() error } -type InteractionController interface { +type Controller interface { SendMessage(message string) HandleUserInput() - HandleCommand(command string, commandBuffer *bytes.Buffer) - HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) + HandleCommand(command string) + HandleSlugEditMode(connection ssh.Channel, char byte) HandleSlugSave(conn ssh.Channel) - HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) + HandleSlugCancel(connection ssh.Channel) HandleSlugUpdateError() ShowWelcomeMessage() DisplaySlugEditor() SetChannel(channel ssh.Channel) SetLifecycle(lifecycle Lifecycle) + SetSlugModificator(func(oldSlug, newSlug string) bool) } type Forwarder interface { @@ -39,13 +44,14 @@ type Forwarder interface { } type Interaction struct { - CommandBuffer *bytes.Buffer - EditMode bool - EditSlug string - channel ssh.Channel - SlugManager slug.Manager - Forwarder Forwarder - Lifecycle Lifecycle + CommandBuffer *bytes.Buffer + EditMode bool + EditSlug string + channel ssh.Channel + SlugManager slug.Manager + Forwarder Forwarder + Lifecycle Lifecycle + updateClientSlug func(oldSlug, newSlug string) bool } func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { @@ -67,7 +73,6 @@ func (i *Interaction) SendMessage(message string) { } func (i *Interaction) HandleUserInput() { - var commandBuffer bytes.Buffer buf := make([]byte, 1) i.EditMode = false @@ -84,42 +89,42 @@ func (i *Interaction) HandleUserInput() { char := buf[0] if i.EditMode { - i.HandleSlugEditMode(i.channel, char, &commandBuffer) + i.HandleSlugEditMode(i.channel, char) continue } i.SendMessage(string(buf[:n])) if char == 8 || char == 127 { - if commandBuffer.Len() > 0 { - commandBuffer.Truncate(commandBuffer.Len() - 1) + if i.CommandBuffer.Len() > 0 { + i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) i.SendMessage("\b \b") } continue } if char == '/' { - commandBuffer.Reset() - commandBuffer.WriteByte(char) + i.CommandBuffer.Reset() + i.CommandBuffer.WriteByte(char) continue } - if commandBuffer.Len() > 0 { + if i.CommandBuffer.Len() > 0 { if char == 13 { - i.HandleCommand(commandBuffer.String(), &commandBuffer) + i.HandleCommand(i.CommandBuffer.String()) continue } - commandBuffer.WriteByte(char) + i.CommandBuffer.WriteByte(char) } } } } -func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) { +func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte) { if char == 13 { i.HandleSlugSave(connection) } else if char == 27 { - i.HandleSlugCancel(connection, commandBuffer) + i.HandleSlugCancel(connection) } else if char == 8 || char == 127 { if len(i.EditSlug) > 0 { i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1] @@ -160,13 +165,13 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { return } if isValid { - //oldSlug := i.SlugManager.Get() + oldSlug := i.SlugManager.Get() newSlug := i.EditSlug - //if !i.updateClientSlug(oldSlug, newSlug) { - // i.HandleSlugUpdateError() - // return - //} + if !i.updateClientSlug(oldSlug, newSlug) { + i.HandleSlugUpdateError() + return + } _, err := connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")) if err != nil { @@ -251,7 +256,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { i.CommandBuffer.Reset() } -func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) { +func (i *Interaction) HandleSlugCancel(connection ssh.Channel) { i.EditMode = false _, err := connection.Write([]byte("\033[H\033[2J")) if err != nil { @@ -278,7 +283,7 @@ func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *by } i.ShowWelcomeMessage() - commandBuffer.Reset() + i.CommandBuffer.Reset() } func (i *Interaction) HandleSlugUpdateError() { @@ -296,7 +301,7 @@ func (i *Interaction) HandleSlugUpdateError() { } } -func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) { +func (i *Interaction) HandleCommand(command string) { switch command { case "/bye": i.SendMessage("\r\nClosing connection...") @@ -307,7 +312,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) } return case "/help": - i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug") + i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug\r\n") case "/clear": i.SendMessage("\033[H\033[2J") i.ShowWelcomeMessage() @@ -323,7 +328,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) } case "/slug": if i.Forwarder.GetTunnelType() != types.HTTP { - i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType()))) + i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType())) } else { i.EditMode = true i.EditSlug = i.SlugManager.Get() @@ -335,7 +340,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) i.SendMessage("Unknown command") } - commandBuffer.Reset() + i.CommandBuffer.Reset() } func (i *Interaction) ShowWelcomeMessage() { @@ -401,6 +406,10 @@ func (i *Interaction) DisplaySlugEditor() { i.SendMessage("\r\n\r\n") } +func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) bool) { + i.updateClientSlug = modificator +} + func centerText(text string, width int) string { padding := (width - len(text)) / 2 if padding < 0 { @@ -408,6 +417,7 @@ func centerText(text string, width int) string { } return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding) } + func isValidSlug(slug string) bool { if len(slug) < 3 || len(slug) > 20 { return false @@ -436,10 +446,6 @@ func waitForKeyPress(connection ssh.Channel) { } } -var forbiddenSlug = []string{ - "ping", -} - func isForbiddenSlug(slug string) bool { for _, s := range forbiddenSlug { if slug == s { diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 2038c2a..29b02ed 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -7,6 +7,7 @@ import ( "log" "net" "time" + portUtil "tunnel_pls/internal/port" "tunnel_pls/session/slug" "tunnel_pls/types" @@ -20,6 +21,7 @@ type Interaction interface { type Forwarder interface { Close() error GetTunnelType() types.TunnelType + GetForwardedPort() uint16 } type Lifecycle struct { @@ -27,9 +29,14 @@ type Lifecycle struct { Conn ssh.Conn Channel ssh.Channel - Interaction Interaction - Forwarder Forwarder - SlugManager slug.Manager + Interaction Interaction + Forwarder Forwarder + SlugManager slug.Manager + unregisterClient func(slug string) +} + +func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) { + l.unregisterClient = unregisterClient } type SessionLifecycle interface { @@ -39,6 +46,7 @@ type SessionLifecycle interface { GetConnection() ssh.Conn GetChannel() ssh.Channel SetChannel(channel ssh.Channel) + SetUnregisterClient(unregisterClient func(slug string)) } func (l *Lifecycle) GetChannel() ssh.Channel { @@ -84,15 +92,9 @@ func (l *Lifecycle) WaitForRunningStatus() { func (l *Lifecycle) Close() error { err := l.Forwarder.Close() - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { return err } - //if s.Forwarder.Listener != nil { - // err := s.Forwarder.Listener.Close() - // if err != nil && !errors.Is(err, net.ErrClosed) { - // return err - // } - //} if l.Channel != nil { err := l.Channel.Close() @@ -108,17 +110,17 @@ func (l *Lifecycle) Close() error { } } - //clientSlug := l.SlugManager.Get() - //if clientSlug != "" { - // unregisterClient(clientSlug) - //} + clientSlug := l.SlugManager.Get() + if clientSlug != "" { + l.unregisterClient(clientSlug) + } - //if l.Forwarder.GetType() == "TCP" && s.Forwarder.Listener != nil { - // err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) - // if err != nil { - // return err - // } - //} + if l.Forwarder.GetTunnelType() == types.TCP { + err := portUtil.Manager.SetPortStatus(l.Forwarder.GetForwardedPort(), false) + if err != nil { + return err + } + } return nil } diff --git a/session/session.go b/session/session.go index b45b932..e122e38 100644 --- a/session/session.go +++ b/session/session.go @@ -1,6 +1,7 @@ package session import ( + "bytes" "log" "sync" "tunnel_pls/session/forwarder" @@ -12,11 +13,12 @@ import ( "golang.org/x/crypto/ssh" ) -type Session interface { - lifecycle.Lifecycle - interaction.InteractionController - forwarder.ForwardingController +var ( + clientsMutex sync.RWMutex + Clients = make(map[string]*SSHSession) +) +type Session interface { HandleGlobalRequest(ch <-chan *ssh.Request) HandleTCPIPForward(req *ssh.Request) HandleHTTPForward(req *ssh.Request, port uint16) @@ -25,7 +27,7 @@ type Session interface { type SSHSession struct { Lifecycle lifecycle.SessionLifecycle - Interaction interaction.InteractionController + Interaction interaction.Controller Forwarder forwarder.ForwardingController SlugManager slug.Manager } @@ -39,7 +41,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan SlugManager: slugManager, } interactionManager := &interaction.Interaction{ - CommandBuffer: nil, + CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)), EditMode: false, EditSlug: "", SlugManager: slugManager, @@ -54,13 +56,18 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan Forwarder: forwarderManager, SlugManager: slugManager, } + + interactionManager.SetLifecycle(lifecycleManager) + interactionManager.SetSlugModificator(updateClientSlug) + forwarderManager.SetLifecycle(lifecycleManager) + lifecycleManager.SetUnregisterClient(unregisterClient) + session := &SSHSession{ Lifecycle: lifecycleManager, Interaction: interactionManager, Forwarder: forwarderManager, SlugManager: slugManager, } - interactionManager.SetLifecycle(lifecycleManager) go func() { go session.Lifecycle.WaitForRunningStatus() @@ -70,7 +77,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan if session.Lifecycle.GetChannel() == nil { session.Lifecycle.SetChannel(ch) session.Interaction.SetChannel(ch) - //session.Interaction.channel = ch session.Lifecycle.SetStatus(types.SETUP) go session.HandleGlobalRequest(forwardingReq) } @@ -84,10 +90,24 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan }() } -var ( - clientsMutex sync.RWMutex - Clients = make(map[string]*SSHSession) -) +func updateClientSlug(oldSlug, newSlug string) bool { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + if _, exists := Clients[newSlug]; exists && newSlug != oldSlug { + return false + } + + client, ok := Clients[oldSlug] + if !ok { + return false + } + + delete(Clients, oldSlug) + client.SlugManager.Set(newSlug) + Clients[newSlug] = client + return true +} func registerClient(slug string, session *SSHSession) bool { clientsMutex.Lock() From 659f6c3ee7f4b56812bb2d364757ae3a9d0222a5 Mon Sep 17 00:00:00 2001 From: bagas Date: Fri, 5 Dec 2025 13:49:33 +0700 Subject: [PATCH 3/7] refactor: move CreateForwardedTCPIPPayload to forwarder interface --- go.mod | 9 ++---- go.sum | 13 ++++---- server/http.go | 3 +- server/server.go | 41 ------------------------- session/forwarder/forwarder.go | 48 ++++++++++++++++-------------- session/handler.go | 2 +- session/interaction/interaction.go | 2 +- 7 files changed, 36 insertions(+), 82 deletions(-) diff --git a/go.mod b/go.mod index 31fdc54..09be3c3 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,8 @@ module tunnel_pls go 1.24.4 require ( - github.com/a-h/templ v0.3.833 github.com/joho/godotenv v1.5.1 - golang.org/x/crypto v0.32.0 - golang.org/x/net v0.33.0 + golang.org/x/crypto v0.45.0 ) -require ( - github.com/gorilla/websocket v1.5.3 // indirect - golang.org/x/sys v0.29.0 // indirect -) +require golang.org/x/sys v0.38.0 // indirect diff --git a/go.sum b/go.sum index e14b727..27269bf 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,13 @@ -github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU= -github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= -github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= diff --git a/server/http.go b/server/http.go index a69b836..3d9ac0f 100644 --- a/server/http.go +++ b/server/http.go @@ -298,8 +298,7 @@ func Handler(conn net.Conn) { } func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { - originHost, originPort := ParseAddr(cw.RemoteAddr.String()) - payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort()) + payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr) channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) diff --git a/server/server.go b/server/server.go index 0e6bdb6..9d01817 100644 --- a/server/server.go +++ b/server/server.go @@ -1,13 +1,10 @@ package server import ( - "bytes" - "encoding/binary" "fmt" "log" "net" "net/http" - "strconv" "tunnel_pls/utils" "golang.org/x/crypto/ssh" @@ -58,41 +55,3 @@ func (s *Server) Start() { go s.handleConnection(conn) } } - -func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { - var buf bytes.Buffer - - writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(port)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - writeSSHString(&buf, host) - err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - - return buf.Bytes() -} - -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return - } - buffer.WriteString(str) -} - -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) -} diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 41c9602..450184d 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -37,6 +37,7 @@ type ForwardingController interface { Close() error HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) SetLifecycle(lifecycle Lifecycle) + CreateForwardedTCPIPPayload(origin net.Addr) []byte } func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { @@ -53,8 +54,7 @@ func (f *Forwarder) AcceptTCPConnections() { log.Printf("Error accepting connection: %v", err) continue } - originHost, originPort := ParseAddr(conn.RemoteAddr().String()) - payload := createForwardedTCPIPPayload(originHost, uint16(originPort), f.GetForwardedPort()) + payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) @@ -129,33 +129,18 @@ func (f *Forwarder) Close() error { return nil } -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) -} -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return - } - buffer.WriteString(str) -} - -func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { +func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { var buf bytes.Buffer + host, originPort := parseAddr(origin.String()) + writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(port)) + err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort())) if err != nil { log.Printf("Failed to write string to buffer: %v", err) return nil } + writeSSHString(&buf, host) err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) if err != nil { @@ -165,3 +150,22 @@ func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { return buf.Bytes() } + +func parseAddr(addr string) (string, uint16) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint16(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint16(port) +} + +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return + } + buffer.WriteString(str) +} diff --git a/session/handler.go b/session/handler.go index 9123310..e2a77f7 100644 --- a/session/handler.go +++ b/session/handler.go @@ -212,7 +212,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind s.Forwarder.SetListener(listener) s.Forwarder.SetForwardedPort(portToBind) s.Interaction.ShowWelcomeMessage() - s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) go s.Forwarder.AcceptTCPConnections() diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 0f6c3ca..3f3db3f 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -324,7 +324,7 @@ func (i *Interaction) HandleCommand(command string) { } i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) } else { - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.Forwarder.GetTunnelType(), domain, i.Forwarder.GetForwardedPort())) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort())) } case "/slug": if i.Forwarder.GetTunnelType() != types.HTTP { From 990bccbff7b271e7c95de001741e07ac26ed5091 Mon Sep 17 00:00:00 2001 From: bagas Date: Fri, 5 Dec 2025 22:24:46 +0700 Subject: [PATCH 4/7] update: handle message deletion properly --- session/interaction/interaction.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 3f3db3f..181b3a4 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -44,6 +44,7 @@ type Forwarder interface { } type Interaction struct { + InputLength int CommandBuffer *bytes.Buffer EditMode bool EditSlug string @@ -96,13 +97,18 @@ func (i *Interaction) HandleUserInput() { i.SendMessage(string(buf[:n])) if char == 8 || char == 127 { + if i.InputLength > 0 { + //i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) + i.SendMessage("\b \b") + } if i.CommandBuffer.Len() > 0 { i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) - i.SendMessage("\b \b") } continue } + i.InputLength += n + if char == '/' { i.CommandBuffer.Reset() i.CommandBuffer.WriteByte(char) @@ -111,6 +117,7 @@ func (i *Interaction) HandleUserInput() { if i.CommandBuffer.Len() > 0 { if char == 13 { + i.SendMessage("\033[K") i.HandleCommand(i.CommandBuffer.String()) continue } From af951b8fa7b3a04698da2cbe025ed5e39eb6ce29 Mon Sep 17 00:00:00 2001 From: bagas Date: Fri, 5 Dec 2025 22:26:38 +0700 Subject: [PATCH 5/7] fix: discard unused buffers in the ssh channel before disconnecting --- server/http.go | 32 +++--------------------------- session/forwarder/forwarder.go | 16 ++++++++++++++- session/interaction/interaction.go | 1 - types/types.go | 5 +++++ 4 files changed, 23 insertions(+), 31 deletions(-) diff --git a/server/http.go b/server/http.go index 3d9ac0f..18bbf38 100644 --- a/server/http.go +++ b/server/http.go @@ -12,16 +12,10 @@ import ( "strings" "tunnel_pls/session" "tunnel_pls/session/interaction" + "tunnel_pls/types" "tunnel_pls/utils" - - "golang.org/x/crypto/ssh" ) -var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + - "Content-Length: 11\r\n" + - "Content-Type: text/plain\r\n\r\n" + - "Bad Gateway") - type CustomWriter struct { RemoteAddr net.Addr writer io.Writer @@ -130,7 +124,7 @@ func isHTTPHeader(buf []byte) bool { } func (cw *CustomWriter) Write(p []byte) (int, error) { - if len(p) == len(BadGatewayResponse) && bytes.Equal(p, BadGatewayResponse) { + if len(p) == len(types.BadGatewayResponse) && bytes.Equal(p, types.BadGatewayResponse) { return cw.writer.Write(p) } @@ -216,7 +210,7 @@ func NewHTTPServer() error { func Handler(conn net.Conn) { defer func() { err := conn.Close() - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { log.Printf("Error closing connection: %v", err) return } @@ -302,20 +296,8 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) - sendBadGatewayResponse(cw) return } - defer func(channel ssh.Channel) { - err := channel.Close() - if err != nil { - if errors.Is(err, io.EOF) { - sendBadGatewayResponse(cw) - return - } - log.Println("Failed to close connection:", err) - return - } - }(channel) go func() { for req := range reqs { @@ -352,11 +334,3 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) return } - -func sendBadGatewayResponse(writer io.Writer) { - _, err := writer.Write(BadGatewayResponse) - if err != nil { - log.Printf("failed to write Bad Gateway response: %v", err) - return - } -} diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 450184d..3d846e6 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -38,6 +38,7 @@ type ForwardingController interface { HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) SetLifecycle(lifecycle Lifecycle) CreateForwardedTCPIPPayload(origin net.Addr) []byte + WriteBadGatewayResponse(dst io.Writer) } func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { @@ -76,7 +77,12 @@ func (f *Forwarder) AcceptTCPConnections() { func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { defer func(src ssh.Channel) { - err := src.Close() + _, err := io.Copy(io.Discard, src) + if err != nil { + log.Printf("Failed to discard connection: %v", err) + } + + err = src.Close() if err != nil && !errors.Is(err, io.EOF) { log.Printf("Error closing connection: %v", err) } @@ -122,6 +128,14 @@ func (f *Forwarder) GetListener() net.Listener { return f.Listener } +func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { + _, err := dst.Write(types.BadGatewayResponse) + if err != nil { + log.Printf("failed to write Bad Gateway response: %v", err) + return + } +} + func (f *Forwarder) Close() error { if f.GetTunnelType() != types.HTTP { return f.Listener.Close() diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 181b3a4..0c998c4 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -98,7 +98,6 @@ func (i *Interaction) HandleUserInput() { if char == 8 || char == 127 { if i.InputLength > 0 { - //i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) i.SendMessage("\b \b") } if i.CommandBuffer.Len() > 0 { diff --git a/types/types.go b/types/types.go index c007661..f909da5 100644 --- a/types/types.go +++ b/types/types.go @@ -14,3 +14,8 @@ const ( HTTP TunnelType = "HTTP" TCP TunnelType = "TCP" ) + +var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + + "Content-Length: 11\r\n" + + "Content-Type: text/plain\r\n\r\n" + + "Bad Gateway") From 368cc0b3e30a0c39542f9c91721f02e4a4e7ff53 Mon Sep 17 00:00:00 2001 From: bagas Date: Sat, 6 Dec 2025 00:01:25 +0700 Subject: [PATCH 6/7] fix: resolve nil pointer dereference in interaction on TLS request --- docker-compose.yaml | 23 +++++++++++++++++++++++ server/http.go | 10 ++++++---- server/https.go | 2 +- server/middleware.go | 7 ++++--- 4 files changed, 34 insertions(+), 8 deletions(-) create mode 100644 docker-compose.yaml diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..a626556 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,23 @@ +services: + tunnlpls: + image: git.fossy.my.id/bagas/tunnl_please:staging + ports: + - 80:80 + - 2200:2200 + volumes: + - ./certs:/certs +# - /etc/letsencrypt/live/sgp.tunnl.live/fullchain.pem:/certs/fullchain.pem +# - /etc/letsencrypt/live/sgp.tunnl.live/privkey.pem:/certs/privkey.pem + labels: + - "com.centurylinklabs.watchtower.enable=true" + environment: + domain: sgp.tunnl.live + port: 2200 + tls_enabled: true + tls_redirect: true + cert_loc: /certs/localhost.direct.SS.crt + key_loc: /certs/localhost.direct.SS.key + ssh_private_key: /certs/id_rsa + cors_list: https://tunnl.live + ALLOWED_PORTS: 10000-50000 + restart: always diff --git a/server/http.go b/server/http.go index 18bbf38..0960932 100644 --- a/server/http.go +++ b/server/http.go @@ -11,11 +11,13 @@ import ( "regexp" "strings" "tunnel_pls/session" - "tunnel_pls/session/interaction" "tunnel_pls/types" "tunnel_pls/utils" ) +type Interaction interface { + SendMessage(message string) +} type CustomWriter struct { RemoteAddr net.Addr writer io.Writer @@ -24,13 +26,13 @@ type CustomWriter struct { buf []byte respHeader *ResponseHeaderFactory reqHeader *RequestHeaderFactory - interaction interaction.Controller + interaction Interaction respMW []ResponseMiddleware reqStartMW []RequestMiddleware reqEndMW []RequestMiddleware } -func (cw *CustomWriter) SetInteraction(interaction interaction.Controller) { +func (cw *CustomWriter) SetInteraction(interaction Interaction) { cw.interaction = interaction } @@ -176,7 +178,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return n, nil } -func (cw *CustomWriter) AddInteraction(interaction *interaction.Interaction) { +func (cw *CustomWriter) AddInteraction(interaction Interaction) { cw.interaction = interaction } diff --git a/server/https.go b/server/https.go index f4ecf99..cbe7c86 100644 --- a/server/https.go +++ b/server/https.go @@ -112,7 +112,7 @@ func HandlerTLS(conn net.Conn) { return } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - + cw.SetInteraction(sshSession.Interaction) forwardRequest(cw, reqhf, sshSession) return } diff --git a/server/middleware.go b/server/middleware.go index ad8c546..a28bdab 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -4,7 +4,6 @@ import ( "fmt" "net" "time" - "tunnel_pls/session/interaction" ) type RequestMiddleware interface { @@ -29,20 +28,22 @@ func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body [ } type RequestLogger struct { - interaction interaction.Controller + interaction Interaction remoteAddr net.Addr } -func NewRequestLogger(interaction interaction.Controller, remoteAddr net.Addr) *RequestLogger { +func NewRequestLogger(interaction Interaction, remoteAddr net.Addr) *RequestLogger { return &RequestLogger{ interaction: interaction, remoteAddr: remoteAddr, } } + func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error { rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path)) return nil } + func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil } //TODO: Implement caching atau enggak From 43178d51b503f559988706f47519cc51ae80d435 Mon Sep 17 00:00:00 2001 From: bagas Date: Sat, 6 Dec 2025 00:03:19 +0700 Subject: [PATCH 7/7] refactor: remove docker compose --- docker-compose.yaml | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 docker-compose.yaml diff --git a/docker-compose.yaml b/docker-compose.yaml deleted file mode 100644 index a626556..0000000 --- a/docker-compose.yaml +++ /dev/null @@ -1,23 +0,0 @@ -services: - tunnlpls: - image: git.fossy.my.id/bagas/tunnl_please:staging - ports: - - 80:80 - - 2200:2200 - volumes: - - ./certs:/certs -# - /etc/letsencrypt/live/sgp.tunnl.live/fullchain.pem:/certs/fullchain.pem -# - /etc/letsencrypt/live/sgp.tunnl.live/privkey.pem:/certs/privkey.pem - labels: - - "com.centurylinklabs.watchtower.enable=true" - environment: - domain: sgp.tunnl.live - port: 2200 - tls_enabled: true - tls_redirect: true - cert_loc: /certs/localhost.direct.SS.crt - key_loc: /certs/localhost.direct.SS.key - ssh_private_key: /certs/id_rsa - cors_list: https://tunnl.live - ALLOWED_PORTS: 10000-50000 - restart: always