From 96dcea1f2ca36da6ee8d07b509492769d64dc7bc Mon Sep 17 00:00:00 2001 From: bagas Date: Sat, 6 Sep 2025 17:17:43 +0700 Subject: [PATCH] feat: add dedicated WebSocket service for subdomain ping --- go.mod | 5 ++- go.sum | 2 ++ server/http.go | 80 ++++++++++++++++++++++++++++++++++++++++++++++ server/https.go | 37 +++++++++++++++++++++ session/handler.go | 23 +++++++++++-- 5 files changed, 144 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 0d8b7b6..31fdc54 100644 --- a/go.mod +++ b/go.mod @@ -9,4 +9,7 @@ require ( golang.org/x/net v0.33.0 ) -require golang.org/x/sys v0.29.0 // indirect +require ( + github.com/gorilla/websocket v1.5.3 // indirect + golang.org/x/sys v0.29.0 // indirect +) diff --git a/go.sum b/go.sum index 6f93cb4..e14b727 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU= github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= diff --git a/server/http.go b/server/http.go index 36b17e1..8ea89b4 100644 --- a/server/http.go +++ b/server/http.go @@ -7,11 +7,57 @@ import ( "fmt" "log" "net" + "net/http" "strings" "tunnel_pls/session" "tunnel_pls/utils" + + "github.com/gorilla/websocket" ) +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +type connResponseWriter struct { + conn net.Conn + header http.Header + wrote bool +} + +func (w *connResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *connResponseWriter) WriteHeader(statusCode int) { + if w.wrote { + return + } + w.wrote = true + fmt.Fprintf(w.conn, "HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) + w.header.Write(w.conn) + fmt.Fprint(w.conn, "\r\n") +} + +func (w *connResponseWriter) Write(b []byte) (int, error) { + if !w.wrote { + w.WriteHeader(http.StatusOK) + } + return w.conn.Write(b) +} + +func (w *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw := bufio.NewReadWriter( + bufio.NewReader(w.conn), + bufio.NewWriter(w.conn), + ) + return w.conn, rw, nil +} + var redirectTLS = false func NewHTTPServer() error { @@ -72,6 +118,40 @@ func Handler(conn net.Conn) { return } + if slug == "ping" { + req, err := http.ReadRequest(reader) + if err != nil { + log.Println("failed to parse HTTP request:", err) + return + } + rw := &connResponseWriter{conn: conn} + + wsConn, err := upgrader.Upgrade(rw, req, nil) + if err != nil { + if !strings.Contains(err.Error(), "the client is not using the websocket protocol") { + log.Println("Upgrade failed:", err) + } + err := conn.Close() + if err != nil { + log.Println("failed to close connection:", err) + return + } + return + } + + err = wsConn.WriteMessage(websocket.TextMessage, []byte("pong")) + if err != nil { + log.Println("failed to write pong:", err) + return + } + err = wsConn.Close() + if err != nil { + log.Println("websocket close failed :", err) + return + } + return + } + sshSession, ok := session.Clients[slug] if !ok { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) diff --git a/server/https.go b/server/https.go index 0b882b7..043e74a 100644 --- a/server/https.go +++ b/server/https.go @@ -6,9 +6,12 @@ import ( "errors" "log" "net" + "net/http" "strings" "tunnel_pls/session" "tunnel_pls/utils" + + "github.com/gorilla/websocket" ) func NewHTTPSServer() error { @@ -62,6 +65,40 @@ func HandlerTLS(conn net.Conn) { } slug := host[0] + if slug == "ping" { + req, err := http.ReadRequest(reader) + if err != nil { + log.Println("failed to parse HTTP request:", err) + return + } + rw := &connResponseWriter{conn: conn} + + wsConn, err := upgrader.Upgrade(rw, req, nil) + if err != nil { + if !strings.Contains(err.Error(), "the client is not using the websocket protocol") { + log.Println("Upgrade failed:", err) + } + err := conn.Close() + if err != nil { + log.Println("failed to close connection:", err) + return + } + return + } + + err = wsConn.WriteMessage(websocket.TextMessage, []byte("pong")) + if err != nil { + log.Println("failed to write pong:", err) + return + } + err = wsConn.Close() + if err != nil { + log.Println("websocket close failed :", err) + return + } + return + } + sshSession, ok := session.Clients[slug] if !ok { conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) diff --git a/session/handler.go b/session/handler.go index fb840d7..c1328a7 100644 --- a/session/handler.go +++ b/session/handler.go @@ -16,8 +16,9 @@ import ( "time" portUtil "tunnel_pls/internal/port" - "golang.org/x/crypto/ssh" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) type SessionStatus string @@ -28,6 +29,10 @@ const ( SETUP SessionStatus = "SETUP" ) +var forbiddenSlug = []string{ + "ping", +} + type UserConnection struct { Reader io.Reader Writer net.Conn @@ -409,12 +414,21 @@ func (s *Session) handleSlugEditMode(connection ssh.Channel, inSlugEditMode *boo } } +func isForbiddenSlug(slug string) bool { + for _, s := range forbiddenSlug { + if slug == s { + return true + } + } + return false +} + func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, editSlug *string, commandBuffer *bytes.Buffer) { isValid := isValidSlug(*editSlug) connection.Write([]byte("\033[H\033[2J")) - if isValid { + if !isValid { oldSlug := s.Slug newSlug := *editSlug @@ -426,6 +440,11 @@ func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, e connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")) connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n")) connection.Write([]byte("Press any key to continue...\r\n")) + } else if isForbiddenSlug(*editSlug) { + connection.Write([]byte("\r\n\r\n❌ FORBIDDEN SUBDOMAIN ❌\r\n\r\n")) + connection.Write([]byte("This subdomain is not allowed.\r\n")) + connection.Write([]byte("Please try a different subdomain.\r\n\r\n")) + connection.Write([]byte("Press any key to continue...\r\n")) } else { connection.Write([]byte("\r\n\r\n❌ INVALID SUBDOMAIN ❌\r\n\r\n")) connection.Write([]byte("Use only lowercase letters, numbers, and hyphens.\r\n"))