diff --git a/main.go b/main.go index 3c5d5c7..abbea0a 100644 --- a/main.go +++ b/main.go @@ -9,12 +9,16 @@ import ( func main() { sshConfig := &ssh.ServerConfig{ - NoClientAuth: true, + NoClientAuth: true, + ServerVersion: "SSH-2.0-TunnlPls-1.0", PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { return nil, nil }, } + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + privateBytes, err := os.ReadFile("id_rsa") if err != nil { log.Fatal("Failed to load private key (./id_rsa)") diff --git a/server/handler.go b/server/handler.go index 51ca861..d167a55 100644 --- a/server/handler.go +++ b/server/handler.go @@ -1,7 +1,6 @@ package server import ( - "fmt" "golang.org/x/crypto/ssh" "log" "net" @@ -16,7 +15,7 @@ func (s *Server) handleConnection(conn net.Conn) { return } - fmt.Println("SSH connection established:", sshConn.User()) + log.Println("SSH connection established:", sshConn.User()) session.New(sshConn, chans, reqs) } diff --git a/server/http.go b/server/http.go index 7190a6e..46a00d5 100644 --- a/server/http.go +++ b/server/http.go @@ -46,14 +46,14 @@ func Handler(conn net.Conn) { reader := bufio.NewReader(conn) headers, err := peekUntilHeaders(reader, 8192) if err != nil { - fmt.Println("Failed to peek headers:", err) + log.Println("Failed to peek headers:", err) return } host := strings.Split(parseHostFromHeader(headers), ".") if len(host) < 1 { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - fmt.Println("Bad Request") + log.Println("Bad Request") conn.Close() return } diff --git a/server/https.go b/server/https.go index 28dc29a..d9407f5 100644 --- a/server/https.go +++ b/server/https.go @@ -4,7 +4,6 @@ import ( "bufio" "crypto/tls" "errors" - "fmt" "golang.org/x/net/context" "log" "net" @@ -46,7 +45,7 @@ func HandlerTLS(conn net.Conn) { reader := bufio.NewReader(conn) headers, err := peekUntilHeaders(reader, 8192) if err != nil { - fmt.Println("Failed to peek headers:", err) + log.Println("Failed to peek headers:", err) return } diff --git a/server/server.go b/server/server.go index bc5079b..f5df009 100644 --- a/server/server.go +++ b/server/server.go @@ -16,7 +16,7 @@ type Server struct { } func NewServer(config ssh.ServerConfig) *Server { - listener, err := net.Listen("tcp", ":2200") + listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("port"))) if err != nil { log.Fatalf("failed to listen on port 2200: %v", err) return nil @@ -45,7 +45,7 @@ func NewServer(config ssh.ServerConfig) *Server { } func (s *Server) Start() { - fmt.Println("SSH server is starting on port 2200...") + log.Println("SSH server is starting on port 2200...") for { conn, err := (*s.Conn).Accept() if err != nil { diff --git a/session/handler.go b/session/handler.go index 99353f7..178e3b1 100644 --- a/session/handler.go +++ b/session/handler.go @@ -111,9 +111,11 @@ func (s *Session) Close() { } func (s *Session) handleGlobalRequest() { + ticker := time.NewTicker(1 * time.Second) for { select { case req := <-s.GlobalRequest: + ticker.Stop() if req == nil { return } @@ -124,6 +126,9 @@ func (s *Session) handleGlobalRequest() { } case <-s.Done: return + case <-ticker.C: + s.sendMessage(fmt.Sprintf("Please specify the forwarding tunnel. For example: 'ssh %s -p %s -R 443:localhost:8080' \r\n\n\n", utils.Getenv("domain"), utils.Getenv("port"))) + s.Close() } } } @@ -137,16 +142,30 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) { if err != nil { log.Println("Failed to read address from payload:", err) req.Reply(false, nil) + s.Close() return } var portToBind uint32 if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil { log.Println("Failed to read port from payload:", err) + s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) req.Reply(false, nil) + s.Close() return } + if isBlockedPort(portToBind) { + s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) + req.Reply(false, nil) + s.Close() + return + } + + s.sendMessage("\033[H\033[2J") + showWelcomeMessage(s.ConnChannels[0]) + s.Status = RUNNING + if portToBind == 80 || portToBind == 443 { s.handleHTTPForward(req, portToBind) return @@ -155,6 +174,23 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) { s.handleTCPForward(req, addr, portToBind) } +var blockedReservedPorts = []uint32{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} + +func isBlockedPort(port uint32) bool { + if port == 80 || port == 443 { + return false + } + if port < 1024 { + return true + } + for _, p := range blockedReservedPorts { + if p == port { + return true + } + } + return false +} + func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint32) { s.TunnelType = HTTP s.ForwardedPort = uint16(portToBind) @@ -190,13 +226,14 @@ func (s *Session) handleTCPForward(req *ssh.Request, addr string, portToBind uin listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { - log.Printf("Failed to bind to port %d: %v", portToBind, err) + s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) req.Reply(false, nil) + s.Close() return } s.Listener = listener s.ForwardedPort = uint16(portToBind) - s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, utils.Getenv("domain"), s.ForwardedPort)) + s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, utils.Getenv("domain"), s.ForwardedPort)) go s.acceptTCPConnections() @@ -466,7 +503,7 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd } s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain)) } else { - s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, domain, s.ForwardedPort)) + s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, domain, s.ForwardedPort)) } case "/slug": @@ -487,10 +524,6 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd } func (s *Session) handleChannelRequests(connection ssh.Channel, requests <-chan *ssh.Request) { - connection.Write([]byte("\033[H\033[2J")) - showWelcomeMessage(connection) - s.Status = RUNNING - go s.handleGlobalRequest() for req := range requests {