From 82eb7af7a61aabbcdd066eb1080cc1328bdb32f2 Mon Sep 17 00:00:00 2001 From: bagas Date: Fri, 7 Feb 2025 03:26:01 +0700 Subject: [PATCH] feat: add subdomain forwarding support for tunnel --- http/http.go | 113 +++++++++++++++++++++++---------------------- session/handler.go | 86 +++++++++++++++++++++++++++------- 2 files changed, 127 insertions(+), 72 deletions(-) diff --git a/http/http.go b/http/http.go index fd7fc52..cf4dee1 100644 --- a/http/http.go +++ b/http/http.go @@ -1,72 +1,73 @@ package httpServer import ( + "bufio" "bytes" "fmt" + "io" "log" - "net" + "net/http" "strings" "tunnel_pls/session" ) -//func ExtractDomain(conn net.Conn) (string, error) { -// defer conn.SetReadDeadline(time.Time{}) // Reset timeout after reading -// conn.SetReadDeadline(time.Now().Add(2 * time.Second)) // Prevent hanging -// -// reader := bufio.NewReader(conn) -// for { -// line, err := reader.ReadString('\n') -// if err != nil { -// return "", err -// } -// -// line = strings.TrimSpace(line) -// if strings.HasPrefix(strings.ToLower(line), "host:") { -// return strings.TrimSpace(strings.SplitN(line, ":", 2)[1]), nil -// } -// -// if line == "" { -// break -// } -// } -// -// return "", fmt.Errorf("host header not found") -//} - -func handleConnection(conn net.Conn) { - defer conn.Close() - - sshSession := session.Clients["test"] - sshSession.HandleForwardedConnection(conn, sshSession.Connection, 80) -} - -func getHost(data []byte) string { - lines := bytes.Split(data, []byte("\n")) - for _, line := range lines { - fmt.Println("here") - if bytes.HasPrefix(line, []byte("Host: ")) { - return strings.TrimSpace(string(line[6:])) - } - } - return "" -} - func Listen() { - listen, err := net.Listen("tcp", ":80") - if err != nil { - log.Fatal("Error starting server:", err) + server := http.Server{ + Addr: ":80", } - defer listen.Close() - fmt.Println("Server listening on port 80") - - for { - conn, err := listen.Accept() - if err != nil { - log.Println("Error accepting connection:", err) - continue + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + var rawRequest string + slug := strings.Split(r.Host, ".")[0] + if slug == "" { + http.Error(w, "You fuck up man", http.StatusBadRequest) + return } + sshSession, ok := session.Clients[slug] + if !ok { + http.Error(w, "Bad Gateway", http.StatusBadGateway) + return + } + rawRequest += fmt.Sprintf("%s %s %s\r\n", r.Method, r.URL.RequestURI(), r.Proto) + rawRequest += fmt.Sprintf("Host: %s\r\n", r.Host) + for k, v := range r.Header { + rawRequest += fmt.Sprintf("%s: %s\r\n", k, v[0]) + } + rawRequest += "\r\n" - go handleConnection(conn) - } + if r.Body != nil { + body, err := io.ReadAll(r.Body) + if err != nil { + log.Println("Error reading request body:", err) + } else { + rawRequest += string(body) + } + } + payload := []byte(rawRequest) + + host, originPort := session.ParseAddr(r.RemoteAddr) + data := sshSession.GetForwardedConnection(host, sshSession.Connection, payload, originPort, 80) + + response, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(data)), r) + if err != nil { + return + } + var isServerSet = false + for k, v := range response.Header { + if k == "Server" { + isServerSet = true + w.Header().Set(k, fmt.Sprintf("Tunnel_Pls/%v", response.Header[k][0])) + continue + } + w.Header().Set(k, v[0]) + } + if !isServerSet { + w.Header().Set("Server", "Tunnel_Pls") + } + w.WriteHeader(response.StatusCode) + io.Copy(w, response.Body) + }) + + fmt.Println("Listening on port 80") + server.ListenAndServe() } diff --git a/session/handler.go b/session/handler.go index d02ba90..cd93eeb 100644 --- a/session/handler.go +++ b/session/handler.go @@ -3,6 +3,7 @@ package session import ( "bufio" "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -189,7 +190,7 @@ func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) { case "shell", "pty-req", "window-change": req.Reply(true, nil) default: - fmt.Println("Unknown request type") + fmt.Println("Unknown request type of : ", req.Type) req.Reply(false, nil) } } @@ -199,8 +200,8 @@ func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) { func (s *Session) HandleForwardedConnection(conn net.Conn, sshConn *ssh.ServerConn, port uint32) { defer conn.Close() log.Printf("Handling new forwarded connection from %s", conn.RemoteAddr()) - - payload := createForwardedTCPIPPayload(conn, port) + host, originPort := ParseAddr(conn.RemoteAddr().String()) + payload := createForwardedTCPIPPayload(host, originPort, port) channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) @@ -236,14 +237,6 @@ func (s *Session) HandleForwardedConnection(conn net.Conn, sshConn *ssh.ServerCo s.ConnChannels[0].Write([]byte("Could not forward request to the tunnel addr\r\n")) return } else { - //if isHttp { - // response, err := http.ReadResponse(reader, nil) - // if err != nil { - // return - // } - // fmt.Println(response) - //} - io.Copy(conn, reader) } @@ -254,24 +247,85 @@ func (s *Session) HandleForwardedConnection(conn net.Conn, sshConn *ssh.ServerCo }() } +func (s *Session) GetForwardedConnection(host string, sshConn *ssh.ServerConn, payload []byte, originPort, port uint32) []byte { + fmt.Println("Here 1") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + channelPayload := createForwardedTCPIPPayload(host, originPort, port) + channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", channelPayload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + return nil + } + fmt.Println("Here 2") + + defer channel.Close() + + head := bytes.NewReader(payload) + go io.Copy(channel, head) + fmt.Println("Here 3") + + go func() { + for req := range reqs { + req.Reply(false, nil) + } + }() + fmt.Println("Here 4") + + var data bytes.Buffer + done := make(chan error, 1) + go func() { + io.Copy(&data, channel) + done <- err + }() + go func() { + var lastSize int + ticker := time.NewTicker(100) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + currentSize := data.Len() + fmt.Println("Size buffer:", currentSize) + + if currentSize == lastSize && currentSize > 0 { + fmt.Println("Buffer size unchanged, closing channel...") + cancel() + return + } + lastSize = currentSize + } + } + }() + select { + case <-ctx.Done(): + return data.Bytes() + case err = <-done: + return data.Bytes() + } +} + func writeSSHString(buffer *bytes.Buffer, str string) { binary.Write(buffer, binary.BigEndian, uint32(len(str))) buffer.WriteString(str) } -func parseAddr(addr string) (string, int) { +func ParseAddr(addr string) (string, uint32) { host, portStr, err := net.SplitHostPort(addr) if err != nil { log.Println("Failed to parse origin address:", err) - return "0.0.0.0", 0 + return "0.0.0.0", uint32(0) } port, _ := strconv.Atoi(portStr) - return host, port + return host, uint32(port) } -func createForwardedTCPIPPayload(conn net.Conn, port uint32) []byte { +func createForwardedTCPIPPayload(host string, originPort, port uint32) []byte { var buf bytes.Buffer - host, originPort := parseAddr(conn.RemoteAddr().String()) writeSSHString(&buf, "localhost") binary.Write(&buf, binary.BigEndian, uint32(port))