diff --git a/server/http.go b/server/http.go index 3f1aaba..d2ad943 100644 --- a/server/http.go +++ b/server/http.go @@ -11,6 +11,7 @@ import ( "regexp" "strings" "tunnel_pls/session" + "tunnel_pls/session/interaction" "tunnel_pls/utils" "golang.org/x/crypto/ssh" @@ -29,12 +30,16 @@ type CustomWriter struct { buf []byte respHeader *ResponseHeaderFactory reqHeader *RequestHeaderFactory - interaction *session.Interaction + interaction interaction.InteractionController respMW []ResponseMiddleware reqStartMW []RequestMiddleware reqEndMW []RequestMiddleware } +func (cw *CustomWriter) SetInteraction(interaction interaction.InteractionController) { + cw.interaction = interaction +} + func (cw *CustomWriter) Read(p []byte) (int, error) { tmp := make([]byte, len(p)) read, err := cw.reader.Read(tmp) @@ -177,7 +182,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return n, nil } -func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) { +func (cw *CustomWriter) AddInteraction(interaction *interaction.Interaction) { cw.interaction = interaction } @@ -287,16 +292,15 @@ func Handler(conn net.Conn) { return } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - + cw.SetInteraction(sshSession.Interaction) forwardRequest(cw, reqhf, sshSession) return } func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { - cw.AddInteraction(sshSession.Interaction) originHost, originPort := ParseAddr(cw.RemoteAddr.String()) payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort()) - channel, reqs, err := sshSession.Conn.OpenChannel("forwarded-tcpip", payload) + channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) sendBadGatewayResponse(cw) diff --git a/server/middleware.go b/server/middleware.go index 08ee035..d5f733b 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -4,7 +4,7 @@ import ( "fmt" "net" "time" - "tunnel_pls/session" + "tunnel_pls/session/interaction" ) type RequestMiddleware interface { @@ -29,13 +29,13 @@ func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body [ } type RequestLogger struct { - interaction session.Interaction + interaction interaction.InteractionController remoteAddr net.Addr } -func NewRequestLogger(interaction *session.Interaction, remoteAddr net.Addr) *RequestLogger { +func NewRequestLogger(interaction interaction.InteractionController, remoteAddr net.Addr) *RequestLogger { return &RequestLogger{ - interaction: *interaction, + interaction: interaction, remoteAddr: remoteAddr, } } diff --git a/session/forwarder.go b/session/forwarder.go deleted file mode 100644 index e7abc17..0000000 --- a/session/forwarder.go +++ /dev/null @@ -1,37 +0,0 @@ -package session - -import ( - "net" - - "golang.org/x/crypto/ssh" -) - -type Forwarder struct { - Listener net.Listener - TunnelType TunnelType - ForwardedPort uint16 - - getSlug func() string - setSlug func(string) -} - -type ForwardingController interface { - HandleGlobalRequest(ch <-chan *ssh.Request) - HandleTCPIPForward(req *ssh.Request) - HandleHTTPForward(req *ssh.Request, port uint16) - HandleTCPForward(req *ssh.Request, addr string, port uint16) - AcceptTCPConnections() -} - -type ForwarderInfo interface { - GetTunnelType() TunnelType - GetForwardedPort() uint16 -} - -func (f *Forwarder) GetTunnelType() TunnelType { - return f.TunnelType -} - -func (f *Forwarder) GetForwardedPort() uint16 { - return f.ForwardedPort -} diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go new file mode 100644 index 0000000..82794ba --- /dev/null +++ b/session/forwarder/forwarder.go @@ -0,0 +1,84 @@ +package forwarder + +import ( + "net" + "tunnel_pls/session/slug" + "tunnel_pls/types" +) + +type Forwarder struct { + Listener net.Listener + TunnelType types.TunnelType + ForwardedPort uint16 + SlugManager slug.Manager +} + +func (f *Forwarder) AcceptTCPConnections() { + panic("implement me") +} + +func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { + panic("implement me") +} + +func (f *Forwarder) SetType(tunnelType types.TunnelType) { + f.TunnelType = tunnelType +} + +func (f *Forwarder) GetTunnelType() types.TunnelType { + return f.TunnelType +} + +func (f *Forwarder) GetForwardedPort() uint16 { + return f.ForwardedPort +} + +func (f *Forwarder) SetForwardedPort(port uint16) { + f.ForwardedPort = port +} + +func (f *Forwarder) SetListener(listener net.Listener) { + f.Listener = listener +} + +func (f *Forwarder) GetListener() net.Listener { + return f.Listener +} + +func (f *Forwarder) Close() error { + if f.GetTunnelType() != types.HTTP { + return f.Listener.Close() + } + return nil +} + +type ForwardingController interface { + AcceptTCPConnections() + UpdateClientSlug(oldSlug, newSlug string) bool + SetType(tunnelType types.TunnelType) + GetTunnelType() types.TunnelType + GetForwardedPort() uint16 + SetForwardedPort(port uint16) + SetListener(listener net.Listener) + GetListener() net.Listener + Close() error +} + +//func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { +// session.clientsMutex.Lock() +// defer session.clientsMutex.Unlock() +// +// if _, exists := session.Clients[newSlug]; exists && newSlug != oldSlug { +// return false +// } +// +// client, ok := session.Clients[oldSlug] +// if !ok { +// return false +// } +// +// delete(session.Clients, oldSlug) +// f.SlugManager.Set(newSlug) +// session.Clients[newSlug] = client +// return true +//} diff --git a/session/handler.go b/session/handler.go index 5c63338..c807d65 100644 --- a/session/handler.go +++ b/session/handler.go @@ -9,92 +9,24 @@ import ( "log" "net" "strconv" - "sync" - "time" portUtil "tunnel_pls/internal/port" + "tunnel_pls/types" "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) -type Status string - -var forbiddenSlug = []string{ - "ping", -} - type UserConnection struct { Reader io.Reader Writer net.Conn } -var ( - clientsMutex sync.RWMutex - Clients = make(map[string]*SSHSession) -) - -func registerClient(slug string, session *SSHSession) bool { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - if _, exists := Clients[slug]; exists { - return false - } - - Clients[slug] = session - return true -} - -func unregisterClient(slug string) { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - delete(Clients, slug) -} - -func (s *SSHSession) Close() error { - if s.Forwarder.Listener != nil { - err := s.Forwarder.Listener.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - return err - } - } - - if s.channel != nil { - err := s.channel.Close() - if err != nil && !errors.Is(err, io.EOF) { - return err - } - } - - if s.Conn != nil { - err := s.Conn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - return err - } - } - - slug := s.Forwarder.getSlug() - if slug != "" { - unregisterClient(slug) - } - - if s.Forwarder.TunnelType == TCP && s.Forwarder.Listener != nil { - err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) - if err != nil { - return err - } - } - - return nil -} - func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { for req := range GlobalRequest { switch req.Type { case "tcpip-forward": - s.handleTCPIPForward(req) + s.HandleTCPIPForward(req) return case "shell", "pty-req", "window-change": err := req.Reply(true, nil) @@ -113,7 +45,7 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { } } -func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { +func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { log.Println("Port forwarding request detected") reader := bytes.NewReader(req.Payload) @@ -126,7 +58,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -142,7 +74,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -156,7 +88,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -172,7 +104,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -180,11 +112,11 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { } s.Interaction.SendMessage("\033[H\033[2J") - s.Lifecycle.Status = RUNNING + s.Lifecycle.SetStatus(types.RUNNING) go s.Interaction.HandleUserInput() if portToBind == 80 || portToBind == 443 { - s.handleHTTPForward(req, portToBind) + s.HandleHTTPForward(req, portToBind) return } else { if portToBind == 0 { @@ -197,7 +129,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -210,7 +142,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -222,7 +154,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { return } } - s.handleTCPForward(req, addr, portToBind) + s.HandleTCPForward(req, addr, portToBind) } var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} @@ -242,9 +174,9 @@ func isBlockedPort(port uint16) bool { return false } -func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { - s.Forwarder.TunnelType = HTTP - s.Forwarder.ForwardedPort = portToBind +func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { + s.Forwarder.SetType(types.HTTP) + s.Forwarder.SetForwardedPort(portToBind) slug := generateUniqueSlug() if slug == "" { @@ -256,7 +188,7 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { return } - s.Forwarder.setSlug(slug) + s.SlugManager.Set(slug) registerClient(slug, s) buf := new(bytes.Buffer) @@ -282,8 +214,8 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { } } -func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - s.Forwarder.TunnelType = TCP +func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { + s.Forwarder.SetType(types.TCP) log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) @@ -294,16 +226,16 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } return } - s.Forwarder.Listener = listener - s.Forwarder.ForwardedPort = portToBind + s.Forwarder.SetListener(listener) + s.Forwarder.SetForwardedPort(portToBind) s.Interaction.ShowWelcomeMessage() - s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.TunnelType, utils.Getenv("domain"), s.Forwarder.ForwardedPort)) + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) go s.acceptTCPConnections() @@ -323,7 +255,7 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind func (s *SSHSession) acceptTCPConnections() { for { - conn, err := s.Forwarder.Listener.Accept() + conn, err := s.Forwarder.GetListener().Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -333,7 +265,7 @@ func (s *SSHSession) acceptTCPConnections() { } originHost, originPort := ParseAddr(conn.RemoteAddr().String()) payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort()) - channel, reqs, err := s.Conn.OpenChannel("forwarded-tcpip", payload) + channel, reqs, err := s.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) return @@ -371,71 +303,6 @@ func generateUniqueSlug() string { return "" } -func (s *SSHSession) waitForRunningStatus() { - timeout := time.After(3 * time.Second) - ticker := time.NewTicker(150 * time.Millisecond) - defer ticker.Stop() - frames := []string{"-", "\\", "|", "/"} - i := 0 - for { - select { - case <-ticker.C: - s.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) - i = (i + 1) % len(frames) - if s.Lifecycle.Status == RUNNING { - s.Interaction.SendMessage("\r\033[K") - return - } - case <-timeout: - s.Interaction.SendMessage("\r\033[K") - s.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") - err := s.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } - log.Println("Timeout waiting for session to start running") - return - } - } -} - -func isForbiddenSlug(slug string) bool { - for _, s := range forbiddenSlug { - if slug == s { - return true - } - } - return false -} - -func isValidSlug(slug string) bool { - if len(slug) < 3 || len(slug) > 20 { - return false - } - - if slug[0] == '-' || slug[len(slug)-1] == '-' { - return false - } - - for _, c := range slug { - if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') { - return false - } - } - - return true -} - -func waitForKeyPress(connection ssh.Channel) { - keyBuf := make([]byte, 1) - for { - _, err := connection.Read(keyBuf) - if err == nil { - break - } - } -} - func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { defer func(src ssh.Channel) { err := src.Close() diff --git a/session/interaction.go b/session/interaction/interaction.go similarity index 83% rename from session/interaction.go rename to session/interaction/interaction.go index cfa1ce1..d9e65d6 100644 --- a/session/interaction.go +++ b/session/interaction/interaction.go @@ -1,4 +1,4 @@ -package session +package interaction import ( "bytes" @@ -7,21 +7,35 @@ import ( "log" "strings" "time" + "tunnel_pls/session/slug" + "tunnel_pls/types" "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) +type Lifecycle interface { + Close() error +} + type InteractionController interface { SendMessage(message string) HandleUserInput() - HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer) - HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer) + HandleCommand(command string, commandBuffer *bytes.Buffer) + HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) + HandleSlugSave(conn ssh.Channel) + HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) HandleSlugUpdateError() ShowWelcomeMessage() DisplaySlugEditor() + SetChannel(channel ssh.Channel) + SetLifecycle(lifecycle Lifecycle) +} + +type Forwarder interface { + Close() error + GetTunnelType() types.TunnelType + GetForwardedPort() uint16 } type Interaction struct { @@ -29,13 +43,17 @@ type Interaction struct { EditMode bool EditSlug string channel ssh.Channel + SlugManager slug.Manager + Forwarder Forwarder + Lifecycle Lifecycle +} - getSlug func() string - setSlug func(string) +func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { + i.Lifecycle = lifecycle +} - session SessionCloser - - forwarder ForwarderInfo +func (i *Interaction) SetChannel(channel ssh.Channel) { + i.channel = channel } func (i *Interaction) SendMessage(message string) { @@ -142,13 +160,13 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { return } if isValid { - oldSlug := i.getSlug() + //oldSlug := i.SlugManager.Get() newSlug := i.EditSlug - if !updateClientSlug(oldSlug, newSlug) { - i.HandleSlugUpdateError() - return - } + //if !i.updateClientSlug(oldSlug, newSlug) { + // i.HandleSlugUpdateError() + // return + //} _, err := connection.Write([]byte("\r\n\r\nāœ… SUBDOMAIN UPDATED āœ…\r\n\r\n")) if err != nil { @@ -223,7 +241,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { if utils.Getenv("tls_enabled") == "true" { protocol = "https" } - _, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.getSlug(), domain))) + _, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))) if err != nil { log.Printf("failed to write to channel: %v", err) return @@ -271,7 +289,7 @@ func (i *Interaction) HandleSlugUpdateError() { i.SendMessage(fmt.Sprintf("Disconnecting in %d...\r\n", iter)) time.Sleep(1 * time.Second) } - err := i.session.Close() + err := i.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) return @@ -282,7 +300,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) switch command { case "/bye": i.SendMessage("\r\nClosing connection...") - err := i.session.Close() + err := i.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) return @@ -294,21 +312,21 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) i.SendMessage("\033[H\033[2J") i.ShowWelcomeMessage() domain := utils.Getenv("domain") - if i.forwarder.GetTunnelType() == HTTP { + if i.Forwarder.GetTunnelType() == types.HTTP { protocol := "http" if utils.Getenv("tls_enabled") == "true" { protocol = "https" } - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.getSlug(), domain)) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) } else { - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.forwarder.GetTunnelType(), domain, i.forwarder.GetForwardedPort())) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.Forwarder.GetTunnelType(), domain, i.Forwarder.GetForwardedPort())) } case "/slug": - if i.forwarder.GetTunnelType() != HTTP { - i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.forwarder.GetTunnelType()))) + if i.Forwarder.GetTunnelType() != types.HTTP { + i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType()))) } else { i.EditMode = true - i.EditSlug = i.getSlug() + i.EditSlug = i.SlugManager.Get() i.SendMessage("\033[H\033[2J") i.DisplaySlugEditor() i.SendMessage("āž¤ " + i.EditSlug + "." + utils.Getenv("domain")) @@ -347,7 +365,7 @@ func (i *Interaction) ShowWelcomeMessage() { func (i *Interaction) DisplaySlugEditor() { domain := utils.Getenv("domain") - fullDomain := i.getSlug() + "." + domain + fullDomain := i.SlugManager.Get() + "." + domain const paddingRight = 4 @@ -383,25 +401,6 @@ func (i *Interaction) DisplaySlugEditor() { i.SendMessage("\r\n\r\n") } -func updateClientSlug(oldSlug, newSlug string) bool { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - if _, exists := Clients[newSlug]; exists && newSlug != oldSlug { - return false - } - - client, ok := Clients[oldSlug] - if !ok { - return false - } - - delete(Clients, oldSlug) - client.Forwarder.setSlug(newSlug) - Clients[newSlug] = client - return true -} - func centerText(text string, width int) string { padding := (width - len(text)) / 2 if padding < 0 { @@ -409,3 +408,43 @@ func centerText(text string, width int) string { } return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding) } +func isValidSlug(slug string) bool { + if len(slug) < 3 || len(slug) > 20 { + return false + } + + if slug[0] == '-' || slug[len(slug)-1] == '-' { + return false + } + + for _, c := range slug { + if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') { + return false + } + } + + return true +} + +func waitForKeyPress(connection ssh.Channel) { + keyBuf := make([]byte, 1) + for { + _, err := connection.Read(keyBuf) + if err == nil { + break + } + } +} + +var forbiddenSlug = []string{ + "ping", +} + +func isForbiddenSlug(slug string) bool { + for _, s := range forbiddenSlug { + if slug == s { + return true + } + } + return false +} diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go new file mode 100644 index 0000000..2038c2a --- /dev/null +++ b/session/lifecycle/lifecycle.go @@ -0,0 +1,124 @@ +package lifecycle + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "time" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" +) + +type Interaction interface { + SendMessage(string) +} + +type Forwarder interface { + Close() error + GetTunnelType() types.TunnelType +} + +type Lifecycle struct { + Status types.Status + Conn ssh.Conn + Channel ssh.Channel + + Interaction Interaction + Forwarder Forwarder + SlugManager slug.Manager +} + +type SessionLifecycle interface { + Close() error + WaitForRunningStatus() + SetStatus(status types.Status) + GetConnection() ssh.Conn + GetChannel() ssh.Channel + SetChannel(channel ssh.Channel) +} + +func (l *Lifecycle) GetChannel() ssh.Channel { + return l.Channel +} + +func (l *Lifecycle) SetChannel(channel ssh.Channel) { + l.Channel = channel +} +func (l *Lifecycle) GetConnection() ssh.Conn { + return l.Conn +} +func (l *Lifecycle) SetStatus(status types.Status) { + l.Status = status +} +func (l *Lifecycle) WaitForRunningStatus() { + timeout := time.After(3 * time.Second) + ticker := time.NewTicker(150 * time.Millisecond) + defer ticker.Stop() + frames := []string{"-", "\\", "|", "/"} + i := 0 + for { + select { + case <-ticker.C: + l.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) + i = (i + 1) % len(frames) + if l.Status == types.RUNNING { + l.Interaction.SendMessage("\r\033[K") + return + } + case <-timeout: + l.Interaction.SendMessage("\r\033[K") + l.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") + err := l.Close() + if err != nil { + log.Printf("failed to close session: %v", err) + } + log.Println("Timeout waiting for session to start running") + return + } + } +} + +func (l *Lifecycle) Close() error { + err := l.Forwarder.Close() + if err != nil { + return err + } + //if s.Forwarder.Listener != nil { + // err := s.Forwarder.Listener.Close() + // if err != nil && !errors.Is(err, net.ErrClosed) { + // return err + // } + //} + + if l.Channel != nil { + err := l.Channel.Close() + if err != nil && !errors.Is(err, io.EOF) { + return err + } + } + + if l.Conn != nil { + err := l.Conn.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + return err + } + } + + //clientSlug := l.SlugManager.Get() + //if clientSlug != "" { + // unregisterClient(clientSlug) + //} + + //if l.Forwarder.GetType() == "TCP" && s.Forwarder.Listener != nil { + // err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) + // if err != nil { + // return err + // } + //} + + return nil +} diff --git a/session/session.go b/session/session.go index 2a38c6a..b45b932 100644 --- a/session/session.go +++ b/session/session.go @@ -1,105 +1,82 @@ package session import ( - "bytes" "log" "sync" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" + "tunnel_pls/types" "golang.org/x/crypto/ssh" ) -const ( - INITIALIZING Status = "INITIALIZING" - RUNNING Status = "RUNNING" - SETUP Status = "SETUP" -) - -type TunnelType string - -const ( - HTTP TunnelType = "http" - TCP TunnelType = "tcp" -) - -type SessionLifecycle interface { - Close() error - WaitForRunningStatus() -} - -type SessionCloser interface { - Close() error -} - type Session interface { - SessionLifecycle - InteractionController - ForwardingController -} + lifecycle.Lifecycle + interaction.InteractionController + forwarder.ForwardingController -type Lifecycle struct { - Status Status + HandleGlobalRequest(ch <-chan *ssh.Request) + HandleTCPIPForward(req *ssh.Request) + HandleHTTPForward(req *ssh.Request, port uint16) + HandleTCPForward(req *ssh.Request, addr string, port uint16) } type SSHSession struct { - Lifecycle *Lifecycle - Interaction *Interaction - Forwarder *Forwarder - - Conn *ssh.ServerConn - channel ssh.Channel - - slug string - slugMu sync.RWMutex + Lifecycle lifecycle.SessionLifecycle + Interaction interaction.InteractionController + Forwarder forwarder.ForwardingController + SlugManager slug.Manager } func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { - session := SSHSession{ - Lifecycle: &Lifecycle{ - Status: INITIALIZING, - }, - Interaction: &Interaction{ - CommandBuffer: new(bytes.Buffer), - EditMode: false, - EditSlug: "", - channel: nil, - getSlug: nil, - setSlug: nil, - session: nil, - forwarder: nil, - }, - Forwarder: &Forwarder{ - Listener: nil, - TunnelType: "", - ForwardedPort: 0, - getSlug: nil, - setSlug: nil, - }, - Conn: conn, - channel: nil, - slug: "", + slugManager := slug.NewManager() + forwarderManager := &forwarder.Forwarder{ + Listener: nil, + TunnelType: "", + ForwardedPort: 0, + SlugManager: slugManager, } - - session.Forwarder.getSlug = session.GetSlug - session.Forwarder.setSlug = session.SetSlug - session.Interaction.getSlug = session.GetSlug - session.Interaction.setSlug = session.SetSlug - session.Interaction.session = &session - session.Interaction.forwarder = session.Forwarder + interactionManager := &interaction.Interaction{ + CommandBuffer: nil, + EditMode: false, + EditSlug: "", + SlugManager: slugManager, + Forwarder: forwarderManager, + Lifecycle: nil, + } + lifecycleManager := &lifecycle.Lifecycle{ + Status: "", + Conn: conn, + Channel: nil, + Interaction: interactionManager, + Forwarder: forwarderManager, + SlugManager: slugManager, + } + session := &SSHSession{ + Lifecycle: lifecycleManager, + Interaction: interactionManager, + Forwarder: forwarderManager, + SlugManager: slugManager, + } + interactionManager.SetLifecycle(lifecycleManager) go func() { - go session.waitForRunningStatus() + go session.Lifecycle.WaitForRunningStatus() for channel := range sshChan { ch, reqs, _ := channel.Accept() - if session.channel == nil { - session.channel = ch - session.Interaction.channel = ch - session.Lifecycle.Status = SETUP + if session.Lifecycle.GetChannel() == nil { + session.Lifecycle.SetChannel(ch) + session.Interaction.SetChannel(ch) + //session.Interaction.channel = ch + session.Lifecycle.SetStatus(types.SETUP) go session.HandleGlobalRequest(forwardingReq) } go session.HandleGlobalRequest(reqs) } - err := session.Close() + err := session.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -107,14 +84,26 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan }() } -func (s *SSHSession) GetSlug() string { - s.slugMu.RLock() - defer s.slugMu.RUnlock() - return s.slug +var ( + clientsMutex sync.RWMutex + Clients = make(map[string]*SSHSession) +) + +func registerClient(slug string, session *SSHSession) bool { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + if _, exists := Clients[slug]; exists { + return false + } + + Clients[slug] = session + return true } -func (s *SSHSession) SetSlug(slug string) { - s.slugMu.Lock() - s.slug = slug - s.slugMu.Unlock() +func unregisterClient(slug string) { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + delete(Clients, slug) } diff --git a/session/slug/slug.go b/session/slug/slug.go new file mode 100644 index 0000000..4900e22 --- /dev/null +++ b/session/slug/slug.go @@ -0,0 +1,32 @@ +package slug + +import "sync" + +type Manager interface { + Get() string + Set(slug string) +} + +type manager struct { + slug string + slugMu sync.RWMutex +} + +func NewManager() Manager { + return &manager{ + slug: "", + slugMu: sync.RWMutex{}, + } +} + +func (s *manager) Get() string { + s.slugMu.RLock() + defer s.slugMu.RUnlock() + return s.slug +} + +func (s *manager) Set(slug string) { + s.slugMu.Lock() + s.slug = slug + s.slugMu.Unlock() +} diff --git a/types/types.go b/types/types.go new file mode 100644 index 0000000..c007661 --- /dev/null +++ b/types/types.go @@ -0,0 +1,16 @@ +package types + +type Status string + +const ( + INITIALIZING Status = "INITIALIZING" + RUNNING Status = "RUNNING" + SETUP Status = "SETUP" +) + +type TunnelType string + +const ( + HTTP TunnelType = "HTTP" + TCP TunnelType = "TCP" +)