diff --git a/server/http.go b/server/http.go index 1676cec..7190a6e 100644 --- a/server/http.go +++ b/server/http.go @@ -8,6 +8,7 @@ import ( "golang.org/x/net/context" "log" "net" + "strconv" "strings" "time" "tunnel_pls/session" @@ -42,10 +43,8 @@ func NewHTTPServer() error { } func Handler(conn net.Conn) { - //TODO: Determain deadline time/set custom timeout on env - ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) reader := bufio.NewReader(conn) - headers, err := peekUntilHeaders(reader, 512) + headers, err := peekUntilHeaders(reader, 8192) if err != nil { fmt.Println("Failed to peek headers:", err) return @@ -61,7 +60,6 @@ func Handler(conn net.Conn) { if len(host) < 1 { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - fmt.Println("Bad Request") conn.Close() return } @@ -80,16 +78,27 @@ func Handler(conn net.Conn) { sshSession, ok := session.Clients[slug] if !ok { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - fmt.Println("Bad Request 1") conn.Close() return } + keepalive, timeout := parseConnectionDetails(headers) + var ctx context.Context + var cancel context.CancelFunc + if keepalive { + if timeout >= 300 { + timeout = 300 + } + ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second)) + } else { + ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + } sshSession.HandleForwardedConnection(session.UserConnection{ Reader: reader, Writer: conn, Context: ctx, - }, sshSession.Connection, 80) + Cancel: cancel, + }, sshSession.Connection) return } @@ -122,3 +131,42 @@ func parseHostFromHeader(data []byte) string { } return "" } + +func parseConnectionDetails(data []byte) (keepAlive bool, timeout int) { + keepAlive = false + timeout = 30 + + lines := strings.Split(string(data), "\r\n") + + for _, line := range lines { + if strings.HasPrefix(strings.ToLower(line), "connection:") { + value := strings.TrimSpace(strings.TrimPrefix(strings.ToLower(line), "connection:")) + keepAlive = (value == "keep-alive") + break + } + } + + if keepAlive { + for _, line := range lines { + if strings.HasPrefix(strings.ToLower(line), "keep-alive:") { + value := strings.TrimSpace(strings.TrimPrefix(line, "Keep-Alive:")) + + if strings.Contains(value, "timeout=") { + parts := strings.Split(value, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "timeout=") { + timeoutStr := strings.TrimPrefix(part, "timeout=") + if t, err := strconv.Atoi(timeoutStr); err == nil { + timeout = t + } + } + } + } + break + } + } + } + + return keepAlive, timeout +} diff --git a/server/https.go b/server/https.go index 53b0a7d..28dc29a 100644 --- a/server/https.go +++ b/server/https.go @@ -43,9 +43,8 @@ func NewHTTPSServer() error { } func HandlerTLS(conn net.Conn) { - ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) reader := bufio.NewReader(conn) - headers, err := peekUntilHeaders(reader, 512) + headers, err := peekUntilHeaders(reader, 8192) if err != nil { fmt.Println("Failed to peek headers:", err) return @@ -54,14 +53,12 @@ func HandlerTLS(conn net.Conn) { host := strings.Split(parseHostFromHeader(headers), ".") if len(host) < 1 { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - fmt.Println("Bad Request") conn.Close() return } if len(host) < 1 { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - fmt.Println("Bad Request") conn.Close() return } @@ -70,15 +67,26 @@ func HandlerTLS(conn net.Conn) { sshSession, ok := session.Clients[slug] if !ok { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - fmt.Println("Bad Request 1") conn.Close() return } + keepalive, timeout := parseConnectionDetails(headers) + var ctx context.Context + var cancel context.CancelFunc + if keepalive { + if timeout >= 300 { + timeout = 300 + } + ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second)) + } else { + ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + } sshSession.HandleForwardedConnection(session.UserConnection{ Reader: reader, Writer: conn, Context: ctx, - }, sshSession.Connection, 80) + Cancel: cancel, + }, sshSession.Connection) return } diff --git a/session/handler.go b/session/handler.go index ecb56d5..99353f7 100644 --- a/session/handler.go +++ b/session/handler.go @@ -6,21 +6,108 @@ import ( "encoding/binary" "errors" "fmt" - "golang.org/x/crypto/ssh" - "golang.org/x/net/context" "io" "log" "net" "strconv" "strings" + "sync" "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/net/context" "tunnel_pls/utils" ) +type SessionStatus string + +const ( + INITIALIZING SessionStatus = "INITIALIZING" + RUNNING SessionStatus = "RUNNING" + SETUP SessionStatus = "SETUP" +) + type UserConnection struct { Reader io.Reader Writer net.Conn Context context.Context + Cancel context.CancelFunc +} + +var ( + clientsMutex sync.RWMutex + Clients = make(map[string]*Session) +) + +type Session struct { + Connection *ssh.ServerConn + ConnChannels []ssh.Channel + GlobalRequest <-chan *ssh.Request + Listener net.Listener + TunnelType TunnelType + ForwardedPort uint16 + Status SessionStatus + Slug string + SlugChannel chan bool + Done chan bool +} + +func registerClient(slug string, session *Session) bool { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + if _, exists := Clients[slug]; exists { + return false + } + + Clients[slug] = session + return true +} + +func unregisterClient(slug string) { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + delete(Clients, slug) +} + +func updateClientSlug(oldSlug, newSlug string) bool { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + if _, exists := Clients[newSlug]; exists && newSlug != oldSlug { + return false + } + + client, ok := Clients[oldSlug] + if !ok { + return false + } + + delete(Clients, oldSlug) + client.Slug = newSlug + Clients[newSlug] = client + return true +} + +func (s *Session) Close() { + if s.Listener != nil { + s.Listener.Close() + } + + for _, ch := range s.ConnChannels { + ch.Close() + } + + if s.Connection != nil { + s.Connection.Close() + } + + if s.Slug != "" { + unregisterClient(s.Slug) + } + + close(s.Done) } func (s *Session) handleGlobalRequest() { @@ -32,12 +119,11 @@ func (s *Session) handleGlobalRequest() { } if req.Type == "tcpip-forward" { s.handleTCPIPForward(req) - continue } else { req.Reply(false, nil) } case <-s.Done: - break + return } } } @@ -55,7 +141,6 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) { } var portToBind uint32 - if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil { log.Println("Failed to read port from payload:", err) req.Reply(false, nil) @@ -63,81 +148,435 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) { } if portToBind == 80 || portToBind == 443 { - s.TunnelType = HTTP - s.ForwardedPort = uint16(portToBind) - var slug string - for { - slug = utils.GenerateRandomString(32) - if _, ok := Clients[slug]; ok { + s.handleHTTPForward(req, portToBind) + return + } + + s.handleTCPForward(req, addr, portToBind) +} + +func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint32) { + s.TunnelType = HTTP + s.ForwardedPort = uint16(portToBind) + + slug := s.generateUniqueSlug() + if slug == "" { + req.Reply(false, nil) + return + } + + s.Slug = slug + registerClient(slug, s) + + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint32(80)) + log.Printf("HTTP forwarding approved on port: %d", 80) + + s.waitForRunningStatus() + + domain := utils.Getenv("domain") + protocol := "http" + if utils.Getenv("tls_enabled") == "true" { + protocol = "https" + } + + s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) + req.Reply(true, buf.Bytes()) +} + +func (s *Session) handleTCPForward(req *ssh.Request, addr string, portToBind uint32) { + s.TunnelType = TCP + log.Printf("Requested forwarding on %s:%d", addr, portToBind) + + listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) + if err != nil { + log.Printf("Failed to bind to port %d: %v", portToBind, err) + req.Reply(false, nil) + return + } + s.Listener = listener + s.ForwardedPort = uint16(portToBind) + s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, utils.Getenv("domain"), s.ForwardedPort)) + + go s.acceptTCPConnections() + + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint32(portToBind)) + log.Printf("TCP forwarding approved on port: %d", portToBind) + req.Reply(true, buf.Bytes()) +} + +func (s *Session) acceptTCPConnections() { + for { + conn, err := s.Listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { return } + log.Printf("Error accepting connection: %v", err) + continue + } + + go s.HandleForwardedConnection(UserConnection{ + Reader: nil, + Writer: conn, + Context: context.Background(), + }, s.Connection) + } +} + +func (s *Session) generateUniqueSlug() string { + maxAttempts := 5 + + for i := 0; i < maxAttempts; i++ { + slug := utils.GenerateRandomString(20) + + clientsMutex.RLock() + _, exists := Clients[slug] + clientsMutex.RUnlock() + + if !exists { + return slug + } + } + + log.Println("Failed to generate unique slug after multiple attempts") + return "" +} + +func (s *Session) waitForRunningStatus() { + timeout := time.After(10 * time.Second) + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if s.Status == RUNNING { + return + } + case <-timeout: + log.Println("Timeout waiting for session to start running") + return + } + } +} + +func (s *Session) sendMessage(message string) { + if len(s.ConnChannels) > 0 { + s.ConnChannels[0].Write([]byte(message)) + } +} + +func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) { + connection, requests, err := newChannel.Accept() + if err != nil { + log.Printf("Could not accept channel: %s", err) + return + } + + s.ConnChannels = append(s.ConnChannels, connection) + + go s.handleUserInput(connection) + + go s.handleChannelRequests(connection, requests) +} + +func (s *Session) handleUserInput(connection ssh.Channel) { + var commandBuffer bytes.Buffer + buf := make([]byte, 1) + inSlugEditMode := false + editSlug := s.Slug + + for { + n, err := connection.Read(buf) + if err != nil { + if err != io.EOF { + log.Printf("Error reading from client: %s", err) + } break } - Clients[slug] = s - s.Slug = slug - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint32(80)) - log.Printf("Forwarding approved on port: %d", 80) - //TODO: fix status checking later - for s.Status != RUNNING { - time.Sleep(500 * time.Millisecond) - } - if utils.Getenv("tls_enabled") == "true" { - s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to https://%s.%s \r\n", slug, utils.Getenv("domain")))) - } else { - s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to http://%s.%s \r\n", slug, utils.Getenv("domain")))) - } - req.Reply(true, buf.Bytes()) + if n > 0 { + char := buf[0] - } else { - s.TunnelType = TCP - log.Printf("Requested forwarding on %s:%d", addr, portToBind) + if inSlugEditMode { + s.handleSlugEditMode(connection, &inSlugEditMode, &editSlug, char, &commandBuffer) + continue + } - listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) - if err != nil { - log.Printf("Failed to bind to port %d: %v", portToBind, err) - req.Reply(false, nil) - return - } - s.Listener = listener - s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to %s:%d \r\n", utils.Getenv("domain"), portToBind))) - go func() { - for { - fmt.Println("jalan di bawah") - conn, err := listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) + connection.Write(buf[:n]) + + if char == 8 || char == 127 { + if commandBuffer.Len() > 0 { + commandBuffer.Truncate(commandBuffer.Len() - 1) + connection.Write([]byte("\b \b")) + } + continue + } + + if char == '/' { + commandBuffer.Reset() + commandBuffer.WriteByte(char) + continue + } + + if commandBuffer.Len() > 0 { + if char == 13 { + s.handleCommand(connection, commandBuffer.String(), &inSlugEditMode, &editSlug, &commandBuffer) continue } - - go s.HandleForwardedConnection(UserConnection{ - Reader: nil, - Writer: conn, - Context: context.Background(), - }, s.Connection, portToBind) + commandBuffer.WriteByte(char) } - }() + } + } +} - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint32(portToBind)) +func (s *Session) handleSlugEditMode(connection ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, commandBuffer *bytes.Buffer) { + if char == 13 { + s.handleSlugSave(connection, inSlugEditMode, editSlug, commandBuffer) + } else if char == 27 { + s.handleSlugCancel(connection, inSlugEditMode, commandBuffer) + } else if char == 8 || char == 127 { + if len(*editSlug) > 0 { + *editSlug = (*editSlug)[:len(*editSlug)-1] + connection.Write([]byte("\r\033[K")) + connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain"))) + } + } else if char >= 32 && char <= 126 { + if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' { + *editSlug += string(char) + connection.Write([]byte("\r\033[K")) + connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain"))) + } + } +} - log.Printf("Forwarding approved on port: %d", portToBind) - req.Reply(true, buf.Bytes()) +func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, editSlug *string, commandBuffer *bytes.Buffer) { + isValid := isValidSlug(*editSlug) + + connection.Write([]byte("\033[H\033[2J")) + + if isValid { + oldSlug := s.Slug + newSlug := *editSlug + + if !updateClientSlug(oldSlug, newSlug) { + handleSlugUpdateError(connection, s) + return + } + + connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")) + connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n")) + connection.Write([]byte("Press any key to continue...\r\n")) + } else { + connection.Write([]byte("\r\n\r\n❌ INVALID SUBDOMAIN ❌\r\n\r\n")) + connection.Write([]byte("Use only lowercase letters, numbers, and hyphens.\r\n")) + connection.Write([]byte("Length must be 3-20 characters and cannot start or end with a hyphen.\r\n\r\n")) + connection.Write([]byte("Press any key to continue...\r\n")) } + waitForKeyPress(connection) + + connection.Write([]byte("\033[H\033[2J")) + showWelcomeMessage(connection) + + domain := utils.Getenv("domain") + protocol := "http" + if utils.Getenv("tls_enabled") == "true" { + protocol = "https" + } + connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain))) + + *inSlugEditMode = false + commandBuffer.Reset() +} + +func (s *Session) handleSlugCancel(connection ssh.Channel, inSlugEditMode *bool, commandBuffer *bytes.Buffer) { + *inSlugEditMode = false + connection.Write([]byte("\033[H\033[2J")) + connection.Write([]byte("\r\n\r\n⚠️ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n")) + connection.Write([]byte("Press any key to continue...\r\n")) + + waitForKeyPress(connection) + + connection.Write([]byte("\033[H\033[2J")) + showWelcomeMessage(connection) + + commandBuffer.Reset() +} + +func handleSlugUpdateError(connection ssh.Channel, s *Session) { + connection.Write([]byte("\r\n\r\n❌ SERVER ERROR ❌\r\n\r\n")) + connection.Write([]byte("Failed to update subdomain. You will be disconnected in 5 seconds.\r\n\r\n")) + + for i := 5; i > 0; i-- { + connection.Write([]byte(fmt.Sprintf("Disconnecting in %d...\r\n", i))) + time.Sleep(1 * time.Second) + } + + s.Close() +} + +func isValidSlug(slug string) bool { + if len(slug) < 3 || len(slug) > 20 { + return false + } + + if slug[0] == '-' || slug[len(slug)-1] == '-' { + return false + } + + for _, c := range slug { + if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') { + return false + } + } + + return true +} + +func waitForKeyPress(connection ssh.Channel) { + keyBuf := make([]byte, 1) + for { + _, err := connection.Read(keyBuf) + if err == nil { + break + } + } +} + +func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, commandBuffer *bytes.Buffer) { + switch command { + case "/bye": + connection.Write([]byte("\r\nClosing connection...")) + s.Close() + case "/debug": + log.Println("Client registry:", Clients) + case "/help": + connection.Write([]byte("\r\nAvailable commands: /bye, /help, /clear, /slug")) + case "/clear": + connection.Write([]byte("\033[H\033[2J")) + showWelcomeMessage(s.ConnChannels[0]) + domain := utils.Getenv("domain") + if s.TunnelType == HTTP { + protocol := "http" + if utils.Getenv("tls_enabled") == "true" { + protocol = "https" + } + s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain)) + } else { + s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, domain, s.ForwardedPort)) + } + + case "/slug": + if s.TunnelType != HTTP { + connection.Write([]byte(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", s.TunnelType))) + } else { + *inSlugEditMode = true + *editSlug = s.Slug + connection.Write([]byte("\033[H\033[2J")) + displaySlugEditor(connection, s.Slug) + connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain"))) + } + default: + connection.Write([]byte("\r\nUnknown command")) + } + + commandBuffer.Reset() +} + +func (s *Session) handleChannelRequests(connection ssh.Channel, requests <-chan *ssh.Request) { + connection.Write([]byte("\033[H\033[2J")) + showWelcomeMessage(connection) + s.Status = RUNNING + + go s.handleGlobalRequest() + + for req := range requests { + switch req.Type { + case "shell", "pty-req", "window-change": + req.Reply(true, nil) + default: + log.Println("Unknown request type:", req.Type) + req.Reply(false, nil) + } + } +} + +func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn) { + defer conn.Writer.Close() + + log.Printf("Handling new forwarded connection from %s", conn.Writer.RemoteAddr()) + host, originPort := ParseAddr(conn.Writer.RemoteAddr().String()) + + timestamp := time.Now().Format("02/Jan/2006 15:04:05") + + payload := createForwardedTCPIPPayload(host, uint16(originPort), s.ForwardedPort) + channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + sendBadGatewayResponse(conn.Writer) + return + } + defer channel.Close() + + go handleChannelRequests(reqs, conn, channel, s.SlugChannel) + + if conn.Reader == nil { + conn.Reader = bufio.NewReader(conn.Writer) + } + + go io.Copy(channel, conn.Reader) + + reader := bufio.NewReader(channel) + _, err = reader.Peek(1) + if err == io.EOF { + s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, "Could not forward request to the tunnel addr")) + sendBadGatewayResponse(conn.Writer) + conn.Writer.Close() + channel.Close() + return + } + + s.sendMessage(fmt.Sprintf("\033[32m %s -> [%s] TUNNEL ADDRESS -- \"%s\" \r\n \033[0m", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) + + io.Copy(conn.Writer, reader) +} + +func handleChannelRequests(reqs <-chan *ssh.Request, conn UserConnection, channel ssh.Channel, slugChannel <-chan bool) { + select { + case <-reqs: + for req := range reqs { + req.Reply(false, nil) + } + case <-conn.Context.Done(): + conn.Writer.Close() + channel.Close() + log.Println("Connection closed by timeout") + return + case <-slugChannel: + conn.Writer.Close() + channel.Close() + log.Println("Connection closed by slug change") + return + } +} + +func sendBadGatewayResponse(writer io.Writer) { + response := "HTTP/1.1 502 Bad Gateway\r\n" + + "Content-Length: 11\r\n" + + "Content-Type: text/plain\r\n\r\n" + + "Bad Gateway" + io.Copy(writer, bytes.NewReader([]byte(response))) } func showWelcomeMessage(connection ssh.Channel) { - fmt.Println("jalan nih") asciiArt := []string{ - ` _______ ____ `, - `|_ __| | | | __ \| | `, - ` | |_ __ _ ___| | | |__) | |___ `, - ` | | | | | '_ \| '_ \ / \ | | __/| / __|`, + ` _______ _ _____ _ `, + `|__ __| | | | __ \| | `, + ` | |_ _ _ __ _ __ ___| | | |__) | |___ `, + ` | | | | | '_ \| '_ \ / _ \ | | ___/| / __|`, ` | | |_| | | | | | | | __/ | | | | \__ \`, ` |_|\__,_|_| |_|_| |_|\___|_| |_| |_|___/`, ``, @@ -158,319 +597,51 @@ func showWelcomeMessage(connection ssh.Channel) { } func displaySlugEditor(connection ssh.Channel, currentSlug string) { + domain := utils.Getenv("domain") + fullDomain := currentSlug + "." + domain + + const paddingRight = 4 + + contentLine := " ║ Current: " + fullDomain + boxWidth := len(contentLine) + paddingRight + 1 + if boxWidth < 50 { + boxWidth = 50 + } + + topBorder := " ╔" + strings.Repeat("═", boxWidth-4) + "╗\r\n" + title := centerText("SUBDOMAIN EDITOR", boxWidth-4) + header := " ║" + title + "║\r\n" + midBorder := " ╠" + strings.Repeat("═", boxWidth-4) + "╣\r\n" + emptyLine := " ║" + strings.Repeat(" ", boxWidth-4) + "║\r\n" + + currentLineContent := fmt.Sprintf(" ║ Current: %s", fullDomain) + currentLine := currentLineContent + strings.Repeat(" ", boxWidth-len(currentLineContent)+1) + "║\r\n" + + newLine := " ║ New:" + strings.Repeat(" ", boxWidth-10) + "║\r\n" + saveCancel := " ║ [Enter] Save | [Esc] Cancel" + strings.Repeat(" ", boxWidth-35) + "║\r\n" + bottomBorder := " ╚" + strings.Repeat("═", boxWidth-4) + "╝\r\n" + + connection.Write([]byte("\r\n\r\n")) + connection.Write([]byte(topBorder)) + connection.Write([]byte(header)) + connection.Write([]byte(midBorder)) + connection.Write([]byte(emptyLine)) + connection.Write([]byte(currentLine)) + connection.Write([]byte(emptyLine)) + connection.Write([]byte(newLine)) + connection.Write([]byte(emptyLine)) + connection.Write([]byte(midBorder)) + connection.Write([]byte(saveCancel)) + connection.Write([]byte(bottomBorder)) connection.Write([]byte("\r\n\r\n")) - connection.Write([]byte(" ╔══════════════════════════════════════════════╗\r\n")) - connection.Write([]byte(" ║ SUBDOMAIN EDITOR ║\r\n")) - connection.Write([]byte(" ╠══════════════════════════════════════════════╣\r\n")) - connection.Write([]byte(" ║ ║\r\n")) - connection.Write([]byte(" ║ Current: " + currentSlug + "." + utils.Getenv("domain") + strings.Repeat(" ", max(0, 30-len(currentSlug)-len(utils.Getenv("domain")))) + "║\r\n")) - connection.Write([]byte(" ║ ║\r\n")) - connection.Write([]byte(" ║ New: ║\r\n")) - connection.Write([]byte(" ║ ║\r\n")) - connection.Write([]byte(" ╠══════════════════════════════════════════════╣\r\n")) - connection.Write([]byte(" ║ [Enter] Save | [Esc] Cancel ║\r\n")) - connection.Write([]byte(" ╚══════════════════════════════════════════════╝\r\n\r\n")) } -func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) { - connection, requests, err := newChannel.Accept() - s.ConnChannels = append(s.ConnChannels, connection) - if err != nil { - log.Printf("Could not accept channel: %s", err) - return +func centerText(text string, width int) string { + padding := (width - len(text)) / 2 + if padding < 0 { + padding = 0 } - go func() { - var commandBuffer bytes.Buffer - buf := make([]byte, 1) - inSlugEditMode := false - editSlug := s.Slug - - for { - n, err := connection.Read(buf) - if n > 0 { - char := buf[0] - - if inSlugEditMode { - if char == 13 { - isValid := true - if len(editSlug) < 3 || len(editSlug) > 20 { - isValid = false - } else { - for _, c := range editSlug { - if !((c >= 'a' && c <= 'z') || - (c >= '0' && c <= '9') || - c == '-') { - isValid = false - break - } - } - if editSlug[0] == '-' || editSlug[len(editSlug)-1] == '-' { - isValid = false - } - } - - connection.Write([]byte("\033[H\033[2J")) - - if isValid { - oldSlug := s.Slug - newSlug := editSlug - - client, ok := Clients[oldSlug] - if !ok { - connection.Write([]byte("\r\n\r\n❌ SERVER ERROR ❌\r\n\r\n")) - connection.Write([]byte("Failed to update subdomain. You will be disconnected in 5 seconds.\r\n\r\n")) - - for i := 5; i > 0; i-- { - connection.Write([]byte(fmt.Sprintf("Disconnecting in %d...\r\n", i))) - time.Sleep(1 * time.Second) - } - - s.Close() - return - } - - if _, exists := Clients[newSlug]; exists && newSlug != oldSlug { - connection.Write([]byte("\r\n\r\n❌ SUBDOMAIN ALREADY IN USE ❌\r\n\r\n")) - connection.Write([]byte("This subdomain is already taken. Please try another one.\r\n\r\n")) - connection.Write([]byte("Press any key to continue...\r\n")) - - waitForKeyPress := true - for waitForKeyPress { - keyBuf := make([]byte, 1) - _, err := connection.Read(keyBuf) - if err == nil { - waitForKeyPress = false - } - } - - connection.Write([]byte("\033[H\033[2J")) - inSlugEditMode = true - editSlug = oldSlug - - displaySlugEditor(connection, oldSlug) - connection.Write([]byte("➤ " + editSlug + "." + utils.Getenv("domain"))) - continue - } - - delete(Clients, oldSlug) - client.Slug = newSlug - //TODO: uneceserry channel - client.SlugChannel <- true - Clients[newSlug] = client - - connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")) - connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n")) - connection.Write([]byte("Press any key to continue...\r\n")) - } else { - connection.Write([]byte("\r\n\r\n❌ INVALID SUBDOMAIN ❌\r\n\r\n")) - connection.Write([]byte("Use only lowercase letters, numbers, and hyphens.\r\n")) - connection.Write([]byte("Length must be 3-20 characters and cannot start or end with a hyphen.\r\n\r\n")) - connection.Write([]byte("Press any key to continue...\r\n")) - } - - waitForKeyPress := true - for waitForKeyPress { - keyBuf := make([]byte, 1) - _, err := connection.Read(keyBuf) - if err == nil { - waitForKeyPress = false - } - } - - connection.Write([]byte("\033[H\033[2J")) - showWelcomeMessage(connection) - if utils.Getenv("tls_enabled") == "true" { - s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to https://%s.%s \r\n", s.Slug, utils.Getenv("domain")))) - } else { - s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to http://%s.%s \r\n", s.Slug, utils.Getenv("domain")))) - } - - inSlugEditMode = false - commandBuffer.Reset() - continue - } else if char == 27 { - inSlugEditMode = false - connection.Write([]byte("\033[H\033[2J")) - connection.Write([]byte("\r\n\r\n⚠️ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n")) - connection.Write([]byte("Press any key to continue...\r\n")) - - waitForKeyPress := true - for waitForKeyPress { - keyBuf := make([]byte, 1) - _, err := connection.Read(keyBuf) - if err == nil { - waitForKeyPress = false - } - } - - connection.Write([]byte("\033[H\033[2J")) - showWelcomeMessage(connection) - - commandBuffer.Reset() - continue - } else if char == 8 || char == 127 { - if len(editSlug) > 0 { - editSlug = editSlug[:len(editSlug)-1] - connection.Write([]byte("\r\033[K")) - connection.Write([]byte("➤ " + editSlug + "." + utils.Getenv("domain"))) - } - continue - } else if char >= 32 && char <= 126 { - if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' { - editSlug += string(char) - connection.Write([]byte("\r\033[K")) - connection.Write([]byte("➤ " + editSlug + "." + utils.Getenv("domain"))) - } - continue - } - continue - } - - connection.Write(buf[:n]) - - if char == 8 || char == 127 { - if commandBuffer.Len() > 0 { - commandBuffer.Truncate(commandBuffer.Len() - 1) - connection.Write([]byte("\b \b")) - } - continue - } - - if char == '/' { - commandBuffer.Reset() - commandBuffer.WriteByte(char) - continue - } - - if commandBuffer.Len() > 0 { - if char == 13 { - command := commandBuffer.String() - fmt.Println("User entered command:", command, "<>") - - if command == "/bye" { - fmt.Println("Closing connection...") - s.Close() - break - } else if command == "/debug" { - fmt.Println(Clients) - } else if command == "/help" { - connection.Write([]byte("\r\nAvailable commands: /bye, /help, /clear, /slug")) - } else if command == "/clear" { - connection.Write([]byte("\033[H\033[2J")) - } else if command == "/slug" { - if s.TunnelType != HTTP { - connection.Write([]byte(fmt.Sprintf("%s cannot be edited", s.TunnelType))) - continue - } - inSlugEditMode = true - editSlug = s.Slug - - connection.Write([]byte("\033[H\033[2J")) - - connection.Write([]byte("\r\n\r\n")) - connection.Write([]byte(" ╔══════════════════════════════════════════════╗\r\n")) - connection.Write([]byte(" ║ SUBDOMAIN EDITOR ║\r\n")) - connection.Write([]byte(" ╠══════════════════════════════════════════════╣\r\n")) - connection.Write([]byte(" ║ ║\r\n")) - connection.Write([]byte(" ║ Current: " + s.Slug + "." + utils.Getenv("domain") + "║\r\n")) - connection.Write([]byte(" ║ ║\r\n")) - connection.Write([]byte(" ║ New: ║\r\n")) - connection.Write([]byte(" ║ ║\r\n")) - connection.Write([]byte(" ╠══════════════════════════════════════════════╣\r\n")) - connection.Write([]byte(" ║ [Enter] Save | [Esc] Cancel ║\r\n")) - connection.Write([]byte(" ╚══════════════════════════════════════════════╝\r\n\r\n")) - - connection.Write([]byte("➤ " + editSlug + "." + utils.Getenv("domain"))) - } else { - connection.Write([]byte("\r\nUnknown command")) - } - - commandBuffer.Reset() - continue - } - - commandBuffer.WriteByte(char) - continue - } - } - - if err != nil { - if err != io.EOF { - log.Printf("Error reading from client: %s", err) - } - break - } - } - }() - - go func() { - connection.Write([]byte("\033[H\033[2J")) - showWelcomeMessage(connection) - s.Status = RUNNING - - go s.handleGlobalRequest() - - for req := range requests { - switch req.Type { - case "shell", "pty-req", "window-change": - req.Reply(true, nil) - default: - fmt.Println("Unknown request type of : ", req.Type) - req.Reply(false, nil) - } - } - }() -} - -func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn, port uint32) { - defer conn.Writer.Close() - - log.Printf("Handling new forwarded connection from %s", conn.Writer.RemoteAddr()) - host, originPort := ParseAddr(conn.Writer.RemoteAddr().String()) - s.ConnChannels[0].Write([]byte(fmt.Sprintf("\033[32m %s -> [%s] TUNNEL ADDRESS -- \"%s\" \r\n \033[0m", conn.Writer.RemoteAddr().String(), s.TunnelType, time.Now().Format("02/Jan/2006 15:04:05")))) - - payload := createForwardedTCPIPPayload(host, uint16(originPort), uint16(port)) - channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - io.Copy(conn.Writer, bytes.NewReader([]byte("HTTP/1.1 502 Bad Gateway\r\nContent-Length: 11\r\nContent-Type: text/plain\r\n\r\nBad Gateway"))) - return - } - defer channel.Close() - - go func() { - select { - case <-reqs: - for req := range reqs { - req.Reply(false, nil) - } - case <-conn.Context.Done(): - conn.Writer.Close() - channel.Close() - fmt.Println("cancel by timeout") - return - case <-s.SlugChannel: - conn.Writer.Close() - channel.Close() - fmt.Println("cancel by slug") - return - } - }() - - defer channel.Close() - if conn.Reader == nil { - conn.Reader = bufio.NewReader(conn.Writer) - } - - go io.Copy(channel, conn.Reader) - reader := bufio.NewReader(channel) - _, err = reader.Peek(1) - if err == io.EOF { - s.ConnChannels[0].Write([]byte("Could not forward request to the tunnel addr 1\r\n")) - return - } - - io.Copy(conn.Writer, reader) + return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding) } func writeSSHString(buffer *bytes.Buffer, str string) { diff --git a/session/session.go b/session/session.go index ce565b2..9e05b72 100644 --- a/session/session.go +++ b/session/session.go @@ -1,34 +1,9 @@ package session import ( - "fmt" - "github.com/google/uuid" "golang.org/x/crypto/ssh" - "net" ) -type STATUS string - -const ( - RUNNING STATUS = "running" - SETUP STATUS = "setup" -) - -type Session struct { - ID uuid.UUID - Slug string - Status STATUS - ConnChannels []ssh.Channel - Connection *ssh.ServerConn - GlobalRequest <-chan *ssh.Request - Listener net.Listener - TunnelType TunnelType - ForwardedPort uint16 - Done chan bool - ForwardedChannel ssh.Channel - SlugChannel chan bool -} - type TunnelType string const ( @@ -38,15 +13,8 @@ const ( UNKNOWN TunnelType = "unknown" ) -var Clients map[string]*Session - -func init() { - Clients = make(map[string]*Session) -} - func New(conn *ssh.ServerConn, sshChannel <-chan ssh.NewChannel, req <-chan *ssh.Request) *Session { session := &Session{ - ID: uuid.New(), Status: SETUP, Slug: "", ConnChannels: []ssh.Channel{}, @@ -65,22 +33,3 @@ func New(conn *ssh.ServerConn, sshChannel <-chan ssh.NewChannel, req <-chan *ssh return session } - -func (session *Session) Close() { - session.Done <- true - if session.TunnelType != HTTP { - session.Listener.Close() - } else { - delete(Clients, session.Slug) - } - - for _, ch := range session.ConnChannels { - if err := ch.Close(); err != nil { - fmt.Println("Error closing channel : ", err.Error()) - continue - } - } - if err := session.Connection.Close(); err != nil { - fmt.Println("Error closing connection : ", err.Error()) - } -}