diff --git a/internal/port/port.go b/internal/port/port.go index 31aafc5..4622a1b 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -1,4 +1,4 @@ -package port +package SSH_PORT import ( "fmt" @@ -45,10 +45,10 @@ func (pm *PortManager) AddPortRange(startPort, endPort uint16) error { if startPort > endPort { return fmt.Errorf("start port cannot be greater than end port") } - for port := startPort; port <= endPort; port++ { - if _, exists := pm.ports[port]; !exists { - pm.ports[port] = false - pm.sortedPorts = append(pm.sortedPorts, port) + for SSH_PORT := startPort; SSH_PORT <= endPort; SSH_PORT++ { + if _, exists := pm.ports[SSH_PORT]; !exists { + pm.ports[SSH_PORT] = false + pm.sortedPorts = append(pm.sortedPorts, SSH_PORT) } } sort.Slice(pm.sortedPorts, func(i, j int) bool { @@ -61,30 +61,30 @@ func (pm *PortManager) GetUnassignedPort() (uint16, bool) { pm.mu.Lock() defer pm.mu.Unlock() - for _, port := range pm.sortedPorts { - if !pm.ports[port] { - pm.ports[port] = true - return port, true + for _, SSH_PORT := range pm.sortedPorts { + if !pm.ports[SSH_PORT] { + pm.ports[SSH_PORT] = true + return SSH_PORT, true } } return 0, false } -func (pm *PortManager) SetPortStatus(port uint16, assigned bool) error { +func (pm *PortManager) SetPortStatus(SSH_PORT uint16, assigned bool) error { pm.mu.Lock() defer pm.mu.Unlock() - if _, exists := pm.ports[port]; !exists { - return fmt.Errorf("port %d is not in the allowed range", port) + if _, exists := pm.ports[SSH_PORT]; !exists { + return fmt.Errorf("port %d is not in the allowed range", SSH_PORT) } - pm.ports[port] = assigned + pm.ports[SSH_PORT] = assigned return nil } -func (pm *PortManager) GetPortStatus(port uint16) (bool, bool) { +func (pm *PortManager) GetPortStatus(SSH_PORT uint16) (bool, bool) { pm.mu.RLock() defer pm.mu.RUnlock() - status, exists := pm.ports[port] + status, exists := pm.ports[SSH_PORT] return status, exists } diff --git a/main.go b/main.go index 1fb275c..79eb4bb 100644 --- a/main.go +++ b/main.go @@ -1,11 +1,12 @@ package main import ( - "golang.org/x/crypto/ssh" "log" "os" "tunnel_pls/server" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) func main() { @@ -20,7 +21,7 @@ func main() { log.SetOutput(os.Stdout) log.SetFlags(log.LstdFlags | log.Lshortfile) - privateBytes, err := os.ReadFile(utils.Getenv("ssh_private_key")) + privateBytes, err := os.ReadFile(utils.Getenv("SSH_PRIVATE_KEY")) if err != nil { log.Fatalf("Failed to load private key : %s", err.Error()) } diff --git a/server/http.go b/server/http.go index f416b67..7501d61 100644 --- a/server/http.go +++ b/server/http.go @@ -63,7 +63,7 @@ var allowedCors = make(map[string]bool) var isAllowedAllCors = false func init() { - corsList := utils.Getenv("cors_list") + corsList := utils.Getenv("CORS_LIST") if corsList == "*" { isAllowedAllCors = true } else { @@ -86,11 +86,11 @@ func NewHTTPServer() error { } } - listener, err := net.Listen("tcp", ":80") + listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("HTTP_PORT"))) if err != nil { return errors.New("Error listening: " + err.Error()) } - if utils.Getenv("tls_enabled") == "true" && utils.Getenv("tls_redirect") == "true" { + if utils.Getenv("TLS_ENABLED") == "true" && utils.Getenv("TLS_ENABLED") == "true" { redirectTLS = true } go func() { @@ -129,7 +129,7 @@ func Handler(conn net.Conn) { if redirectTLS { conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://%s.%s/\r\n", slug, utils.Getenv("domain")) + + fmt.Sprintf("Location: https://%s.%s/\r\n", slug, utils.Getenv("DOMAIN")) + "Content-Length: 0\r\n" + "Connection: close\r\n" + "\r\n")) diff --git a/server/https.go b/server/https.go index fbaf3f9..6d5854c 100644 --- a/server/https.go +++ b/server/https.go @@ -16,7 +16,7 @@ import ( ) func NewHTTPSServer() error { - cert, err := tls.LoadX509KeyPair(utils.Getenv("cert_loc"), utils.Getenv("key_loc")) + cert, err := tls.LoadX509KeyPair(utils.Getenv("CERT_LOC"), utils.Getenv("KEY_LOC")) if err != nil { return err } diff --git a/server/server.go b/server/server.go index 3f5e739..0a9b4d2 100644 --- a/server/server.go +++ b/server/server.go @@ -2,11 +2,12 @@ package server import ( "fmt" - "golang.org/x/crypto/ssh" "log" "net" "net/http" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) type Server struct { @@ -16,12 +17,12 @@ type Server struct { } func NewServer(config ssh.ServerConfig) *Server { - listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("port"))) + listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("SSH_PORT"))) if err != nil { log.Fatalf("failed to listen on port 2200: %v", err) return nil } - if utils.Getenv("tls_enabled") == "true" { + if utils.Getenv("TLS_ENABLED") == "true" { go func() { err := NewHTTPSServer() if err != nil { diff --git a/session/handler.go b/session/handler.go index c1328a7..e3a65ad 100644 --- a/session/handler.go +++ b/session/handler.go @@ -92,7 +92,6 @@ func (s *Session) Close() error { if s.Listener != nil { err := s.Listener.Close() if err != nil && !errors.Is(err, net.ErrClosed) { - fmt.Println("1") return err } } @@ -100,7 +99,6 @@ func (s *Session) Close() error { if s.ConnChannel != nil { err := s.ConnChannel.Close() if err != nil && !errors.Is(err, io.EOF) { - fmt.Println("2") return err } } @@ -108,8 +106,6 @@ func (s *Session) Close() error { if s.Connection != nil { err := s.Connection.Close() if err != nil && !errors.Is(err, net.ErrClosed) { - fmt.Println("3") - return err } } @@ -121,7 +117,6 @@ func (s *Session) Close() error { if s.TunnelType == TCP { err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false) if err != nil { - fmt.Println("4") return err } } @@ -131,16 +126,35 @@ func (s *Session) Close() error { } func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { - for req := range GlobalRequest { - switch req.Type { - case "tcpip-forward": - s.handleTCPIPForward(req) - return - case "shell", "pty-req", "window-change": - req.Reply(true, nil) - default: - log.Println("Unknown request type:", req.Type) - req.Reply(false, nil) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + for { + select { + case req, ok := <-GlobalRequest: + if !ok || req == nil { + log.Println("GlobalRequest channel closed") + return + } + switch req.Type { + case "tcpip-forward": + cancel() + s.handleTCPIPForward(req) + return + case "shell", "pty-req", "window-change": + req.Reply(true, nil) + default: + log.Println("Unknown request type:", req.Type) + req.Reply(false, nil) + } + case <-ctx.Done(): + if s.Status == SETUP { + s.sendMessage("No forwarding request detected. See https://tunnl.live for setup help.\n\r") + err := s.Close() + if err != nil { + log.Println("Cannot close connection: ", err) + return + } + } } } } @@ -216,15 +230,15 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) { var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func isBlockedPort(port uint16) bool { - if port == 80 || port == 443 { +func isBlockedPort(SSH_PORT uint16) bool { + if SSH_PORT == 80 || SSH_PORT == 443 { return false } - if port < 1024 && port != 0 { + if SSH_PORT < 1024 && SSH_PORT != 0 { return true } for _, p := range blockedReservedPorts { - if p == port { + if p == SSH_PORT { return true } } @@ -250,13 +264,13 @@ func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint16) { s.waitForRunningStatus() - domain := utils.Getenv("domain") + DOMAIN := utils.Getenv("DOMAIN") protocol := "http" - if utils.Getenv("tls_enabled") == "true" { + if utils.Getenv("TLS_ENABLED") == "true" { protocol = "https" } - s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) + s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, DOMAIN)) req.Reply(true, buf.Bytes()) } @@ -403,13 +417,13 @@ func (s *Session) handleSlugEditMode(connection ssh.Channel, inSlugEditMode *boo if len(*editSlug) > 0 { *editSlug = (*editSlug)[:len(*editSlug)-1] connection.Write([]byte("\r\033[K")) - connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain"))) + connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN"))) } } else if char >= 32 && char <= 126 { if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' { *editSlug += string(char) connection.Write([]byte("\r\033[K")) - connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain"))) + connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN"))) } } } @@ -438,7 +452,7 @@ func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, e } connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")) - connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n")) + connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("DOMAIN") + "\r\n\r\n")) connection.Write([]byte("Press any key to continue...\r\n")) } else if isForbiddenSlug(*editSlug) { connection.Write([]byte("\r\n\r\n❌ FORBIDDEN SUBDOMAIN ❌\r\n\r\n")) @@ -457,12 +471,12 @@ func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, e connection.Write([]byte("\033[H\033[2J")) showWelcomeMessage(connection) - domain := utils.Getenv("domain") + DOMAIN := utils.Getenv("DOMAIN") protocol := "http" - if utils.Getenv("tls_enabled") == "true" { + if utils.Getenv("TLS_ENABLED") == "true" { protocol = "https" } - connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain))) + connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, DOMAIN))) *inSlugEditMode = false commandBuffer.Reset() @@ -534,15 +548,15 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd case "/clear": connection.Write([]byte("\033[H\033[2J")) showWelcomeMessage(s.ConnChannel) - domain := utils.Getenv("domain") + DOMAIN := utils.Getenv("DOMAIN") if s.TunnelType == HTTP { protocol := "http" - if utils.Getenv("tls_enabled") == "true" { + if utils.Getenv("TLS_ENABLED") == "true" { protocol = "https" } - s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain)) + s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, DOMAIN)) } else { - s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, domain, s.ForwardedPort)) + s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, DOMAIN, s.ForwardedPort)) } case "/slug": @@ -553,7 +567,7 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd *editSlug = s.Slug connection.Write([]byte("\033[H\033[2J")) displaySlugEditor(connection, s.Slug) - connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain"))) + connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN"))) } default: connection.Write([]byte("\r\nUnknown command")) @@ -684,8 +698,8 @@ func showWelcomeMessage(connection ssh.Channel) { } func displaySlugEditor(connection ssh.Channel, currentSlug string) { - domain := utils.Getenv("domain") - fullDomain := currentSlug + "." + domain + DOMAIN := utils.Getenv("DOMAIN") + fullDomain := currentSlug + "." + DOMAIN const paddingRight = 4 @@ -742,15 +756,15 @@ func ParseAddr(addr string) (string, uint32) { log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) return "0.0.0.0", uint32(0) } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) + SSH_PORT, _ := strconv.Atoi(portStr) + return host, uint32(SSH_PORT) } -func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { +func createForwardedTCPIPPayload(host string, originPort, SSH_PORT uint16) []byte { var buf bytes.Buffer writeSSHString(&buf, "localhost") - binary.Write(&buf, binary.BigEndian, uint32(port)) + binary.Write(&buf, binary.BigEndian, uint32(SSH_PORT)) writeSSHString(&buf, host) binary.Write(&buf, binary.BigEndian, uint32(originPort)) diff --git a/session/session.go b/session/session.go index 63177c5..0f53efc 100644 --- a/session/session.go +++ b/session/session.go @@ -1,9 +1,10 @@ package session import ( - "golang.org/x/crypto/ssh" "net" "sync" + + "golang.org/x/crypto/ssh" ) type TunnelType string @@ -44,7 +45,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session { ch, reqs, _ := channel.Accept() if session.ConnChannel == nil { session.ConnChannel = ch - session.Status = RUNNING go session.HandleGlobalRequest(forwardingReq) } go session.HandleGlobalRequest(reqs)