From 8a1604fde84421d68ae0c8b3bc9cf5b82fea79da Mon Sep 17 00:00:00 2001 From: bagas Date: Thu, 6 Feb 2025 22:14:13 +0700 Subject: [PATCH] refactor: separate core components and improve session & server handling --- go.mod | 7 +- go.sum | 6 +- http/http.go | 72 +++++++++++ main.go | 281 +------------------------------------------ proto/proto.go | 50 ++++++++ server/handler.go | 22 ++++ server/server.go | 42 +++++++ session/handler.go | 294 +++++++++++++++++++++++++++++++++++++++++++++ session/session.go | 69 +++++++++++ 9 files changed, 556 insertions(+), 287 deletions(-) create mode 100644 http/http.go create mode 100644 proto/proto.go create mode 100644 server/handler.go create mode 100644 server/server.go create mode 100644 session/handler.go create mode 100644 session/session.go diff --git a/go.mod b/go.mod index 23f0956..938e209 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,8 @@ module tunnel_pls go 1.23 require ( - github.com/kr/pty v1.1.8 + github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be golang.org/x/crypto v0.32.0 ) -require ( - github.com/creack/pty v1.1.7 // indirect - golang.org/x/sys v0.29.0 // indirect -) +require golang.org/x/sys v0.29.0 // indirect diff --git a/go.sum b/go.sum index ae3014c..7270b24 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ -github.com/creack/pty v1.1.7 h1:6pwm8kMQKCmgUg0ZHTm5+/YvRK0s3THD/28+T6/kk4A= -github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/kr/pty v1.1.8 h1:AkaSdXYQOWeaO3neb8EM634ahkXXe3jYbVh/F9lq+GI= -github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be h1:J5BL2kskAlV9ckgEsNQXscjIaLiOYiZ75d4e94E6dcQ= +github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be/go.mod h1:mk5IQ+Y0ZeO87b858TlA645sVcEcbiX6YqP98kt+7+w= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= diff --git a/http/http.go b/http/http.go new file mode 100644 index 0000000..fd7fc52 --- /dev/null +++ b/http/http.go @@ -0,0 +1,72 @@ +package httpServer + +import ( + "bytes" + "fmt" + "log" + "net" + "strings" + "tunnel_pls/session" +) + +//func ExtractDomain(conn net.Conn) (string, error) { +// defer conn.SetReadDeadline(time.Time{}) // Reset timeout after reading +// conn.SetReadDeadline(time.Now().Add(2 * time.Second)) // Prevent hanging +// +// reader := bufio.NewReader(conn) +// for { +// line, err := reader.ReadString('\n') +// if err != nil { +// return "", err +// } +// +// line = strings.TrimSpace(line) +// if strings.HasPrefix(strings.ToLower(line), "host:") { +// return strings.TrimSpace(strings.SplitN(line, ":", 2)[1]), nil +// } +// +// if line == "" { +// break +// } +// } +// +// return "", fmt.Errorf("host header not found") +//} + +func handleConnection(conn net.Conn) { + defer conn.Close() + + sshSession := session.Clients["test"] + sshSession.HandleForwardedConnection(conn, sshSession.Connection, 80) +} + +func getHost(data []byte) string { + lines := bytes.Split(data, []byte("\n")) + for _, line := range lines { + fmt.Println("here") + if bytes.HasPrefix(line, []byte("Host: ")) { + return strings.TrimSpace(string(line[6:])) + } + } + return "" +} + +func Listen() { + listen, err := net.Listen("tcp", ":80") + if err != nil { + log.Fatal("Error starting server:", err) + } + defer listen.Close() + + fmt.Println("Server listening on port 80") + + for { + conn, err := listen.Accept() + if err != nil { + log.Println("Error accepting connection:", err) + continue + } + + go handleConnection(conn) + } +} diff --git a/main.go b/main.go index 90eb6da..3c5d5c7 100644 --- a/main.go +++ b/main.go @@ -1,15 +1,10 @@ package main import ( - "bytes" - "encoding/binary" - "fmt" "golang.org/x/crypto/ssh" - "io" "log" - "net" "os" - "strconv" + "tunnel_pls/server" ) func main() { @@ -31,276 +26,6 @@ func main() { } sshConfig.AddHostKey(private) - listen, err := net.Listen("tcp", ":2200") - if err != nil { - log.Fatal(err) - return - } - log.Println("Listening on port 2200") - - for { - tcpConn, err := listen.Accept() - if err != nil { - log.Fatal(err) - return - } - sshConn, connChan, globalConnChan, err := ssh.NewServerConn(tcpConn, sshConfig) - if err != nil { - log.Printf("Failed to handshake (%s)", err) - continue - } - log.Printf("New SSH connection from %s (%s) with User (%s)", sshConn.RemoteAddr(), sshConn.ClientVersion(), sshConn.User()) - - go handleRequests(globalConnChan, sshConn) - go handleChannels(connChan, sshConn) - } -} - -func handleChannels(chans <-chan ssh.NewChannel, sshConn *ssh.ServerConn) { - for newChannel := range chans { - go handleChannel(newChannel, sshConn) - } -} - -func handleRequests(reqs <-chan *ssh.Request, sshConn *ssh.ServerConn) { - for req := range reqs { - log.Printf("Received global request: %s", req.Type) - - if req.Type == "tcpip-forward" { - log.Println("Port forwarding request detected") - - reader := bytes.NewReader(req.Payload) - - addr, err := readSSHString(reader) - if err != nil { - log.Println("Failed to read address from payload:", err) - req.Reply(false, nil) - continue - } - - var portToBind uint32 - if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil { - log.Println("Failed to read port from payload:", err) - req.Reply(false, nil) - continue - } - - log.Printf("Requested forwarding on %s:%d", addr, portToBind) - - listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) - if err != nil { - log.Printf("Failed to bind to port %d: %v", portToBind, err) - req.Reply(false, nil) - continue - } - - go func() { - for { - conn, err := listener.Accept() - if err != nil { - log.Printf("Error accepting connection: %v", err) - continue - } - go handleForwardedConnection(conn, sshConn, portToBind) - } - }() - - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint32(portToBind)) - - log.Printf("Forwarding approved on port: %d", portToBind) - req.Reply(true, buf.Bytes()) - } else { - req.Reply(false, nil) - } - } -} - -func handleForwardedConnection(conn net.Conn, sshConn *ssh.ServerConn, port uint32) { - defer conn.Close() - log.Printf("Handling new forwarded connection from %s", conn.RemoteAddr()) - - payload := createForwardedTCPIPPayload(conn, port) - channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - return - } - defer channel.Close() - - go io.Copy(channel, conn) - io.Copy(conn, channel) - - go func() { - for req := range reqs { - req.Reply(false, nil) - } - }() -} - -func handleChannel(newChannel ssh.NewChannel, sshConn *ssh.ServerConn) { - switch newChannel.ChannelType() { - case "session": - handleSessionChannel(newChannel) - - case "forwarded-tcpip": - //handleForwardedTCPIP(newChannel) - default: - newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type") - } -} - -func handleSessionChannel(newChannel ssh.NewChannel) { - connection, requests, err := newChannel.Accept() - if err != nil { - log.Printf("Could not accept channel: %s", err) - return - } - var bandwidth uint32 - go func() { - var commandBuffer bytes.Buffer - buf := make([]byte, 1) - for { - n, err := connection.Read(buf) - bandwidth += uint32(n) - fmt.Println("using ", bandwidth) - if n > 0 { - char := buf[0] - connection.Write(buf[:n]) - if char == 8 || char == 127 { - if commandBuffer.Len() > 0 { - commandBuffer.Truncate(commandBuffer.Len() - 1) - connection.Write([]byte("\b \b")) - } - continue - } - - if char == '/' { - commandBuffer.Reset() - commandBuffer.WriteByte(char) - continue - } - - if commandBuffer.Len() > 0 { - if char == 13 { - command := commandBuffer.String() - fmt.Println("User entered command:", command) - - if command == "/bye" { - fmt.Println("Closing connection...") - connection.Close() - return - } else if command == "/help" { - connection.Write([]byte("Available commands: /bye, /help")) - } else { - connection.Write([]byte("Unknown command")) - } - - commandBuffer.Reset() - continue - } - - commandBuffer.WriteByte(char) - continue - } - } - - if err != nil { - if err != io.EOF { - log.Printf("Error reading from client: %s", err) - } - break - } - } - }() - - go func() { - connection.Write([]byte("hello world")) - for req := range requests { - switch req.Type { - case "shell", "pty-req": - req.Reply(true, nil) - default: - fmt.Println("Unknown request type") - req.Reply(false, nil) - } - } - }() -} - -//func handleForwardedTCPIP(newChannel ssh.NewChannel) { -// reader := bytes.NewReader(newChannel.ExtraData()) -// -// destAddr, err := readSSHString(reader) -// if err != nil { -// log.Println("Failed to read destination address:", err) -// newChannel.Reject(ssh.ConnectionFailed, "invalid destination") -// return -// } -// -// var destPort uint32 -// if err := binary.Read(reader, binary.BigEndian, &destPort); err != nil { -// log.Println("Failed to read destination port:", err) -// newChannel.Reject(ssh.ConnectionFailed, "invalid port") -// return -// } -// -// log.Printf("Forwarding connection to %s:%d", destAddr, destPort) -// -// targetConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", destAddr, destPort)) -// fmt.Println("connected to", destAddr) -// if err != nil { -// log.Printf("Failed to connect to %s:%d: %v", destAddr, destPort, err) -// newChannel.Reject(ssh.ConnectionFailed, "could not connect to target") -// return -// } -// -// channel, _, err := newChannel.Accept() -// if err != nil { -// log.Printf("Could not accept forwarded channel: %v", err) -// targetConn.Close() -// return -// } -// -// go io.Copy(channel, targetConn) -// go io.Copy(targetConn, channel) -//} - -func writeSSHString(buffer *bytes.Buffer, str string) { - binary.Write(buffer, binary.BigEndian, uint32(len(str))) - buffer.WriteString(str) -} - -func parseAddr(addr string) (string, int) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Println("Failed to parse origin address:", err) - return "0.0.0.0", 0 - } - port, _ := strconv.Atoi(portStr) - return host, port -} - -func createForwardedTCPIPPayload(conn net.Conn, port uint32) []byte { - var buf bytes.Buffer - host, originPort := parseAddr(conn.RemoteAddr().String()) - - writeSSHString(&buf, "localhost") - binary.Write(&buf, binary.BigEndian, uint32(port)) - writeSSHString(&buf, host) - binary.Write(&buf, binary.BigEndian, uint32(originPort)) - - return buf.Bytes() -} - -func readSSHString(reader *bytes.Reader) (string, error) { - var length uint32 - if err := binary.Read(reader, binary.BigEndian, &length); err != nil { - return "", err - } - strBytes := make([]byte, length) - if _, err := reader.Read(strBytes); err != nil { - return "", err - } - return string(strBytes), nil + app := server.NewServer(*sshConfig) + app.Start() } diff --git a/proto/proto.go b/proto/proto.go new file mode 100644 index 0000000..5290e62 --- /dev/null +++ b/proto/proto.go @@ -0,0 +1,50 @@ +/* +Package proto provides byte-level interaction with HTTP request payload. + +Example of HTTP payload for future references, new line symbols escaped: + + POST /upload HTTP/1.1\r\n + User-Agent: Gor\r\n + Content-Length: 11\r\n + \r\n + Hello world + + GET /index.html HTTP/1.1\r\n + User-Agent: Gor\r\n + \r\n + \r\n + +https://github.com/buger/goreplay/blob/master/proto/proto.go +*/ +package proto + +import ( + "bytes" + "net/http" +) + +var Methods = [...]string{ + http.MethodConnect, http.MethodDelete, http.MethodGet, + http.MethodHead, http.MethodOptions, http.MethodPatch, + http.MethodPost, http.MethodPut, http.MethodTrace, +} + +func Method(payload []byte) []byte { + end := bytes.IndexByte(payload, ' ') + if end == -1 { + return nil + } + + return payload[:end] +} + +func IsHttpRequest(payload []byte) bool { + method := string(Method(payload)) + var methodFound bool + for _, m := range Methods { + if methodFound = method == m; methodFound { + break + } + } + return methodFound +} diff --git a/server/handler.go b/server/handler.go new file mode 100644 index 0000000..51ca861 --- /dev/null +++ b/server/handler.go @@ -0,0 +1,22 @@ +package server + +import ( + "fmt" + "golang.org/x/crypto/ssh" + "log" + "net" + "tunnel_pls/session" +) + +func (s *Server) handleConnection(conn net.Conn) { + sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config) + if err != nil { + log.Printf("failed to establish SSH connection: %v", err) + conn.Close() + return + } + + fmt.Println("SSH connection established:", sshConn.User()) + + session.New(sshConn, chans, reqs) +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..1c4e546 --- /dev/null +++ b/server/server.go @@ -0,0 +1,42 @@ +package server + +import ( + "fmt" + "golang.org/x/crypto/ssh" + "log" + "net" + "net/http" + httpServer "tunnel_pls/http" +) + +type Server struct { + Conn *net.Listener + Config *ssh.ServerConfig + HttpServer *http.Server +} + +func NewServer(config ssh.ServerConfig) *Server { + listener, err := net.Listen("tcp", ":2200") + if err != nil { + log.Fatalf("failed to listen on port 2200: %v", err) + return nil + } + go httpServer.Listen() + return &Server{ + Conn: &listener, + Config: &config, + } +} + +func (s *Server) Start() { + fmt.Println("SSH server is starting on port 2200...") + for { + conn, err := (*s.Conn).Accept() + if err != nil { + log.Printf("failed to accept connection: %v", err) + continue + } + + go s.handleConnection(conn) + } +} diff --git a/session/handler.go b/session/handler.go new file mode 100644 index 0000000..d02ba90 --- /dev/null +++ b/session/handler.go @@ -0,0 +1,294 @@ +package session + +import ( + "bufio" + "bytes" + "encoding/binary" + "errors" + "fmt" + "golang.org/x/crypto/ssh" + "io" + "log" + "net" + "strconv" + "time" + "tunnel_pls/proto" +) + +func (s *Session) handleGlobalRequest() { + for { + select { + case req := <-s.GlobalRequest: + if req == nil { + return + } + if req.Type == "tcpip-forward" { + log.Println("Port forwarding request detected") + + reader := bytes.NewReader(req.Payload) + + addr, err := readSSHString(reader) + if err != nil { + log.Println("Failed to read address from payload:", err) + req.Reply(false, nil) + continue + } + + var portToBind uint32 + + if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil { + log.Println("Failed to read port from payload:", err) + req.Reply(false, nil) + continue + } + + if portToBind == 80 || portToBind == 443 { + s.TunnelType = HTTP + Clients["test"] = s + // TODO: dont forward traffic to the listener below + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint32(portToBind)) + + log.Printf("Forwarding approved on port: %d", portToBind) + req.Reply(true, buf.Bytes()) + } else { + s.TunnelType = TCP + log.Printf("Requested forwarding on %s:%d", addr, portToBind) + + listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) + if err != nil { + log.Printf("Failed to bind to port %d: %v", portToBind, err) + req.Reply(false, nil) + continue + } + s.Listener = listener + go func() { + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + log.Printf("Error accepting connection: %v", err) + continue + } + fmt.Println("ini bind : ", portToBind) + go s.HandleForwardedConnection(conn, s.Connection, portToBind) + } + }() + + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint32(portToBind)) + + log.Printf("Forwarding approved on port: %d", portToBind) + req.Reply(true, buf.Bytes()) + } + + } else { + req.Reply(false, nil) + } + case <-s.Done: + break + } + } +} + +func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) { + connection, requests, err := newChannel.Accept() + s.ConnChannels = append(s.ConnChannels, connection) + if err != nil { + log.Printf("Could not accept channel: %s", err) + return + } + go func() { + var commandBuffer bytes.Buffer + buf := make([]byte, 1) + for { + n, err := connection.Read(buf) + if n > 0 { + char := buf[0] + connection.Write(buf[:n]) + if char == 8 || char == 127 { + if commandBuffer.Len() > 0 { + commandBuffer.Truncate(commandBuffer.Len() - 1) + connection.Write([]byte("\b \b")) + } + continue + } + + if char == '/' { + commandBuffer.Reset() + commandBuffer.WriteByte(char) + continue + } + + if commandBuffer.Len() > 0 { + if char == 13 { + command := commandBuffer.String() + fmt.Println("User entered command:", command, "<>") + + if command == "/bye" { + fmt.Println("Closing connection...") + s.Close() + break + } else if command == "/help" { + connection.Write([]byte("Available commands: /bye, /help, /clear")) + + } else if command == "/clear" { + connection.Write([]byte("\033[H\033[2J")) + } else { + connection.Write([]byte("Unknown command")) + } + + commandBuffer.Reset() + continue + } + + commandBuffer.WriteByte(char) + continue + } + } + + if err != nil { + if err != io.EOF { + log.Printf("Error reading from client: %s", err) + } + break + } + } + }() + + go func() { + asciiArt := []string{ + ` _______ _ _____ _ `, + `|__ __| | | | __ \| | `, + ` | |_ _ _ __ _ __ ___| | | |__) | |___ `, + ` | | | | | '_ \| '_ \ / _ \ | | ___/| / __|`, + ` | | |_| | | | | | | | __/ | | | | \__ \`, + ` |_|\__,_|_| |_|_| |_|\___|_| |_| |_|___/`, + ``, + ` "Tunnel Pls" - Project by Bagas`, + ` https://fossy.my.id`, + ``, + ` Welcome to Tunnel! Available commands:`, + ` - '/bye' : Exit the tunnel`, + ` - '/help' : Show this help message`, + ` - '/clear' : Clear the current line`, + } + + connection.Write([]byte("\033[H\033[2J")) + + for _, line := range asciiArt { + connection.Write([]byte("\r\n" + line)) + } + + connection.Write([]byte("\r\n\r\n")) + + for req := range requests { + switch req.Type { + case "shell", "pty-req", "window-change": + req.Reply(true, nil) + default: + fmt.Println("Unknown request type") + req.Reply(false, nil) + } + } + }() +} + +func (s *Session) HandleForwardedConnection(conn net.Conn, sshConn *ssh.ServerConn, port uint32) { + defer conn.Close() + log.Printf("Handling new forwarded connection from %s", conn.RemoteAddr()) + + payload := createForwardedTCPIPPayload(conn, port) + channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + return + } + defer channel.Close() + + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + + connReader := bufio.NewReader(conn) + + var isHttp bool + header, err := connReader.Peek(7) + + if err != nil { + isHttp = false + } else { + isHttp = proto.IsHttpRequest(header) + } + + conn.SetReadDeadline(time.Time{}) + + go io.Copy(channel, connReader) + + reader := bufio.NewReader(channel) + _, err = reader.Peek(1) + if err == io.EOF { + if isHttp { + io.Copy(conn, bytes.NewReader([]byte("HTTP/1.1 502 Bad Gateway\r\nContent-Length: 11\r\nContent-Type: text/plain\r\n\r\nBad Gateway"))) + } else { + conn.Write([]byte("Could not forward request to the tunnel addr\r\n")) + } + s.ConnChannels[0].Write([]byte("Could not forward request to the tunnel addr\r\n")) + return + } else { + //if isHttp { + // response, err := http.ReadResponse(reader, nil) + // if err != nil { + // return + // } + // fmt.Println(response) + //} + + io.Copy(conn, reader) + } + + go func() { + for req := range reqs { + req.Reply(false, nil) + } + }() +} + +func writeSSHString(buffer *bytes.Buffer, str string) { + binary.Write(buffer, binary.BigEndian, uint32(len(str))) + buffer.WriteString(str) +} + +func parseAddr(addr string) (string, int) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Println("Failed to parse origin address:", err) + return "0.0.0.0", 0 + } + port, _ := strconv.Atoi(portStr) + return host, port +} + +func createForwardedTCPIPPayload(conn net.Conn, port uint32) []byte { + var buf bytes.Buffer + host, originPort := parseAddr(conn.RemoteAddr().String()) + + writeSSHString(&buf, "localhost") + binary.Write(&buf, binary.BigEndian, uint32(port)) + writeSSHString(&buf, host) + binary.Write(&buf, binary.BigEndian, uint32(originPort)) + + return buf.Bytes() +} + +func readSSHString(reader *bytes.Reader) (string, error) { + var length uint32 + if err := binary.Read(reader, binary.BigEndian, &length); err != nil { + return "", err + } + strBytes := make([]byte, length) + if _, err := reader.Read(strBytes); err != nil { + return "", err + } + return string(strBytes), nil +} diff --git a/session/session.go b/session/session.go new file mode 100644 index 0000000..e235266 --- /dev/null +++ b/session/session.go @@ -0,0 +1,69 @@ +package session + +import ( + "fmt" + "golang.org/x/crypto/ssh" + "net" +) + +type Session struct { + ConnChannels []ssh.Channel + Connection *ssh.ServerConn + GlobalRequest <-chan *ssh.Request + Listener net.Listener + TunnelType TunnelType + Done chan bool +} + +type TunnelType string + +const ( + HTTP TunnelType = "http" + TCP TunnelType = "tcp" + UDP TunnelType = "udp" + UNKNOWN TunnelType = "unknown" +) + +var Clients map[string]*Session + +func init() { + Clients = make(map[string]*Session) +} + +func New(conn *ssh.ServerConn, sshChannel <-chan ssh.NewChannel, req <-chan *ssh.Request) *Session { + session := &Session{ + ConnChannels: []ssh.Channel{}, + Connection: conn, + GlobalRequest: req, + TunnelType: UNKNOWN, + Done: make(chan bool), + } + + go session.handleGlobalRequest() + + go func() { + for newChannel := range sshChannel { + go session.HandleSessionChannel(newChannel) + } + }() + + return session +} + +func (session *Session) Close() { + session.Done <- true + + session.Listener.Close() + + for _, ch := range session.ConnChannels { + if err := ch.Close(); err != nil { + fmt.Println("Error closing channel : ", err.Error()) + continue + } + } + + if err := session.Connection.Close(); err != nil { + fmt.Println("Error closing connection : ", err.Error()) + } + +}