diff --git a/http/http.go b/http/http.go index cf4dee1..918f80a 100644 --- a/http/http.go +++ b/http/http.go @@ -2,72 +2,76 @@ package httpServer import ( "bufio" - "bytes" "fmt" "io" "log" + "net" "net/http" "strings" "tunnel_pls/session" ) func Listen() { - server := http.Server{ - Addr: ":80", + server, err := net.Listen("tcp", ":80") + if err != nil { + log.Fatal(err) + return + } + defer server.Close() + log.Println("Listening on :80") + for { + conn, err := server.Accept() + if err != nil { + log.Fatal(err) + return + } + + go handleRequest(conn) + } +} + +func handleRequest(conn net.Conn) { + defer conn.Close() + var rawRequest string + + reader := bufio.NewReader(conn) + r, err := http.ReadRequest(reader) + if err != nil { + fmt.Println("Error reading request:", err) + return } - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - var rawRequest string - slug := strings.Split(r.Host, ".")[0] - if slug == "" { - http.Error(w, "You fuck up man", http.StatusBadRequest) - return - } - sshSession, ok := session.Clients[slug] - if !ok { - http.Error(w, "Bad Gateway", http.StatusBadGateway) - return - } - rawRequest += fmt.Sprintf("%s %s %s\r\n", r.Method, r.URL.RequestURI(), r.Proto) - rawRequest += fmt.Sprintf("Host: %s\r\n", r.Host) - for k, v := range r.Header { - rawRequest += fmt.Sprintf("%s: %s\r\n", k, v[0]) - } - rawRequest += "\r\n" + slug := strings.Split(r.Host, ".")[0] + if slug == "" { + fmt.Println("Error parsing slug: ", r.Host) + return + } - if r.Body != nil { - body, err := io.ReadAll(r.Body) - if err != nil { - log.Println("Error reading request body:", err) - } else { - rawRequest += string(body) - } - } - payload := []byte(rawRequest) + sshSession, ok := session.Clients[slug] + if !ok { + fmt.Println("Error finding ssh session: ", slug) + return + } - host, originPort := session.ParseAddr(r.RemoteAddr) - data := sshSession.GetForwardedConnection(host, sshSession.Connection, payload, originPort, 80) + rawRequest += fmt.Sprintf("%s %s %s\r\n", r.Method, r.URL.RequestURI(), r.Proto) + rawRequest += fmt.Sprintf("Host: %s\r\n", r.Host) - response, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(data)), r) + for k, v := range r.Header { + rawRequest += fmt.Sprintf("%s: %s\r\n", k, v[0]) + } + rawRequest += "\r\n" + + if r.Body != nil { + body, err := io.ReadAll(r.Body) if err != nil { - return + log.Println("Error reading request body:", err) + } else { + rawRequest += string(body) } - var isServerSet = false - for k, v := range response.Header { - if k == "Server" { - isServerSet = true - w.Header().Set(k, fmt.Sprintf("Tunnel_Pls/%v", response.Header[k][0])) - continue - } - w.Header().Set(k, v[0]) - } - if !isServerSet { - w.Header().Set("Server", "Tunnel_Pls") - } - w.WriteHeader(response.StatusCode) - io.Copy(w, response.Body) - }) + } - fmt.Println("Listening on port 80") - server.ListenAndServe() + payload := []byte(rawRequest) + + host, originPort := session.ParseAddr(conn.RemoteAddr().String()) + sshSession.GetForwardedConnection(conn, host, sshSession.Connection, payload, originPort, 80) } diff --git a/main b/main new file mode 100644 index 0000000..64518e5 Binary files /dev/null and b/main differ diff --git a/session/handler.go b/session/handler.go index cd93eeb..f9eb582 100644 --- a/session/handler.go +++ b/session/handler.go @@ -3,7 +3,6 @@ package session import ( "bufio" "bytes" - "context" "encoding/binary" "errors" "fmt" @@ -14,6 +13,7 @@ import ( "strconv" "time" "tunnel_pls/proto" + "tunnel_pls/utils" ) func (s *Session) handleGlobalRequest() { @@ -45,12 +45,19 @@ func (s *Session) handleGlobalRequest() { if portToBind == 80 || portToBind == 443 { s.TunnelType = HTTP - Clients["test"] = s - // TODO: dont forward traffic to the listener below + var slug string + for { + slug = utils.GenerateRandomString(32) + if _, ok := Clients[slug]; ok { + continue + } + break + } + Clients[slug] = s buf := new(bytes.Buffer) binary.Write(buf, binary.BigEndian, uint32(portToBind)) - log.Printf("Forwarding approved on port: %d", portToBind) + s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to http://%s.tunnl.live", slug))) req.Reply(true, buf.Bytes()) } else { s.TunnelType = TCP @@ -73,7 +80,7 @@ func (s *Session) handleGlobalRequest() { log.Printf("Error accepting connection: %v", err) continue } - fmt.Println("ini bind : ", portToBind) + go s.HandleForwardedConnection(conn, s.Connection, portToBind) } }() @@ -184,6 +191,7 @@ func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) { } connection.Write([]byte("\r\n\r\n")) + go s.handleGlobalRequest() for req := range requests { switch req.Type { @@ -247,66 +255,35 @@ func (s *Session) HandleForwardedConnection(conn net.Conn, sshConn *ssh.ServerCo }() } -func (s *Session) GetForwardedConnection(host string, sshConn *ssh.ServerConn, payload []byte, originPort, port uint32) []byte { - fmt.Println("Here 1") - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - +func (s *Session) GetForwardedConnection(conn net.Conn, host string, sshConn *ssh.ServerConn, payload []byte, originPort, port uint32) { + defer conn.Close() channelPayload := createForwardedTCPIPPayload(host, originPort, port) channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", channelPayload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) - return nil + return } - fmt.Println("Here 2") - defer channel.Close() - head := bytes.NewReader(payload) - go io.Copy(channel, head) - fmt.Println("Here 3") + connReader := bufio.NewReader(conn) + initalPayload := bytes.NewReader(payload) + io.Copy(channel, initalPayload) + go io.Copy(channel, connReader) + reader := bufio.NewReader(channel) + _, err = reader.Peek(1) + if err == io.EOF { + io.Copy(conn, bytes.NewReader([]byte("HTTP/1.1 502 Bad Gateway\r\nContent-Length: 11\r\nContent-Type: text/plain\r\n\r\nBad Gateway"))) + s.ConnChannels[0].Write([]byte("Could not forward request to the tunnel addr\r\n")) + return + } else { + io.Copy(conn, reader) + } go func() { for req := range reqs { req.Reply(false, nil) } }() - fmt.Println("Here 4") - - var data bytes.Buffer - done := make(chan error, 1) - go func() { - io.Copy(&data, channel) - done <- err - }() - go func() { - var lastSize int - ticker := time.NewTicker(100) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - currentSize := data.Len() - fmt.Println("Size buffer:", currentSize) - - if currentSize == lastSize && currentSize > 0 { - fmt.Println("Buffer size unchanged, closing channel...") - cancel() - return - } - lastSize = currentSize - } - } - }() - select { - case <-ctx.Done(): - return data.Bytes() - case err = <-done: - return data.Bytes() - } } func writeSSHString(buffer *bytes.Buffer, str string) { @@ -317,7 +294,7 @@ func writeSSHString(buffer *bytes.Buffer, str string) { func ParseAddr(addr string) (string, uint32) { host, portStr, err := net.SplitHostPort(addr) if err != nil { - log.Println("Failed to parse origin address:", err) + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) return "0.0.0.0", uint32(0) } port, _ := strconv.Atoi(portStr) diff --git a/session/session.go b/session/session.go index e235266..420dbc7 100644 --- a/session/session.go +++ b/session/session.go @@ -39,8 +39,6 @@ func New(conn *ssh.ServerConn, sshChannel <-chan ssh.NewChannel, req <-chan *ssh Done: make(chan bool), } - go session.handleGlobalRequest() - go func() { for newChannel := range sshChannel { go session.HandleSessionChannel(newChannel) @@ -65,5 +63,4 @@ func (session *Session) Close() { if err := session.Connection.Close(); err != nil { fmt.Println("Error closing connection : ", err.Error()) } - } diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..da9a779 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,18 @@ +package utils + +import ( + "math/rand" + "strings" + "time" +) + +func GenerateRandomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyz" + seededRand := rand.New(rand.NewSource(time.Now().UnixNano() + int64(rand.Intn(9999)))) + var result strings.Builder + for i := 0; i < length; i++ { + randomIndex := seededRand.Intn(len(charset)) + result.WriteString(string(charset[randomIndex])) + } + return result.String() +}