diff --git a/server/header.go b/server/header.go index 326d617..0b36a2c 100644 --- a/server/header.go +++ b/server/header.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "fmt" - "io" "strings" ) @@ -28,8 +27,7 @@ type RequestHeaderFactory struct { headers map[string]string } -func NewRequestHeaderFactory(r io.Reader) (*RequestHeaderFactory, error) { - br := bufio.NewReader(r) +func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) { header := &RequestHeaderFactory{ headers: make(map[string]string), } diff --git a/server/http.go b/server/http.go index 8a0ba58..cc39d46 100644 --- a/server/http.go +++ b/server/http.go @@ -72,7 +72,7 @@ func (cw *CustomWriter) Read(p []byte) (int, error) { } for _, m := range cw.reqEndMW { - err := m.HandleRequest(cw.reqHeader) + err = m.HandleRequest(cw.reqHeader) if err != nil { log.Printf("Error when applying request middleware: %v", err) return 0, err @@ -212,7 +212,8 @@ func NewHTTPServer() error { } go func() { for { - conn, err := listener.Accept() + var conn net.Conn + conn, err = listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -257,7 +258,7 @@ func Handler(conn net.Conn) { slug := host[0] if redirectTLS { - _, err := conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + + _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + fmt.Sprintf("Location: https://%s.%s/\r\n", slug, utils.Getenv("domain")) + "Content-Length: 0\r\n" + "Connection: close\r\n" + @@ -270,8 +271,7 @@ func Handler(conn net.Conn) { } if slug == "ping" { - // TODO: implement cors - _, err := conn.Write([]byte( + _, err = conn.Write([]byte( "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Connection: close\r\n" + @@ -289,7 +289,7 @@ func Handler(conn net.Conn) { sshSession, ok := session.Clients[slug] if !ok { - _, err := conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + + _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + "Content-Length: 0\r\n" + "Connection: close\r\n" + @@ -298,11 +298,6 @@ func Handler(conn net.Conn) { log.Println("Failed to write 301 Moved Permanently:", err) return } - err = conn.Close() - if err != nil { - log.Println("Failed to close connection:", err) - return - } return } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) @@ -346,7 +341,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS cw.reqHeader = initialRequest for _, m := range cw.reqStartMW { - err := m.HandleRequest(cw.reqHeader) + err = m.HandleRequest(cw.reqHeader) if err != nil { log.Printf("Error handling request: %v", err) return diff --git a/server/https.go b/server/https.go index cbe7c86..4e23d17 100644 --- a/server/https.go +++ b/server/https.go @@ -26,7 +26,8 @@ func NewHTTPSServer() error { go func() { for { - conn, err := ln.Accept() + var conn net.Conn + conn, err = ln.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { log.Println("https server closed") @@ -60,24 +61,18 @@ func HandlerTLS(conn net.Conn) { host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { - _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) + _, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) if err != nil { log.Println("Failed to write 400 Bad Request:", err) return } - err = conn.Close() - if err != nil { - log.Println("Failed to close connection:", err) - return - } return } slug := host[0] if slug == "ping" { - // TODO: implement cors - _, err := conn.Write([]byte( + _, err = conn.Write([]byte( "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Connection: close\r\n" + @@ -95,7 +90,7 @@ func HandlerTLS(conn net.Conn) { sshSession, ok := session.Clients[slug] if !ok { - _, err := conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + + _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + "Content-Length: 0\r\n" + "Connection: close\r\n" + @@ -104,11 +99,6 @@ func HandlerTLS(conn net.Conn) { log.Println("Failed to write 301 Moved Permanently:", err) return } - err = conn.Close() - if err != nil { - log.Println("Failed to close connection:", err) - return - } return } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) diff --git a/server/server.go b/server/server.go index 75c9b89..8051a02 100644 --- a/server/server.go +++ b/server/server.go @@ -24,7 +24,7 @@ func NewServer(config *ssh.ServerConfig) *Server { } if utils.Getenv("tls_enabled") == "true" { go func() { - err := NewHTTPSServer() + err = NewHTTPSServer() if err != nil { log.Fatalf("failed to start https server: %v", err) } @@ -32,7 +32,7 @@ func NewServer(config *ssh.ServerConfig) *Server { }() } go func() { - err := NewHTTPServer() + err = NewHTTPServer() if err != nil { log.Fatalf("failed to start http server: %v", err) } diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 0c998c4..049608e 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -52,6 +52,7 @@ type Interaction struct { SlugManager slug.Manager Forwarder Forwarder Lifecycle Lifecycle + pendingExit bool updateClientSlug func(oldSlug, newSlug string) bool } @@ -94,6 +95,27 @@ func (i *Interaction) HandleUserInput() { continue } + if i.pendingExit { + if char != 3 { + i.pendingExit = false + i.SendMessage("Operation canceled.\r\n") + } + } + + if char == 3 { + if i.pendingExit { + i.SendMessage("Closing connection...\r\n") + err = i.Lifecycle.Close() + if err != nil { + log.Printf("failed to close session: %v", err) + return + } + return + } + i.SendMessage("Please press Ctrl+C again to disconnect.\r\n") + i.pendingExit = true + } + i.SendMessage(string(buf[:n])) if char == 8 || char == 127 { @@ -122,6 +144,11 @@ func (i *Interaction) HandleUserInput() { } i.CommandBuffer.WriteByte(char) } + + if char == 13 { + i.SendMessage("\033[K") + } + } } } @@ -129,7 +156,7 @@ func (i *Interaction) HandleUserInput() { func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte) { if char == 13 { i.HandleSlugSave(connection) - } else if char == 27 { + } else if char == 27 || char == 3 { i.HandleSlugCancel(connection) } else if char == 8 || char == 127 { if len(i.EditSlug) > 0 { @@ -310,7 +337,7 @@ func (i *Interaction) HandleSlugUpdateError() { func (i *Interaction) HandleCommand(command string) { switch command { case "/bye": - i.SendMessage("\r\nClosing connection...") + i.SendMessage("Closing connection...\r\n") err := i.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) @@ -343,7 +370,7 @@ func (i *Interaction) HandleCommand(command string) { i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain")) } default: - i.SendMessage("Unknown command") + i.SendMessage("Unknown command\r\n") } i.CommandBuffer.Reset()