diff --git a/server/header.go b/server/header.go new file mode 100644 index 0000000..cb7602e --- /dev/null +++ b/server/header.go @@ -0,0 +1,163 @@ +package server + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" +) + +type HeaderManager interface { + Get(key string) []byte + Set(key string, value []byte) + Remove(key string) + Finalize() []byte +} + +type ResponseHeaderFactory struct { + startLine []byte + headers map[string]string +} + +type RequestHeaderFactory struct { + Method string + Path string + Version string + startLine []byte + headers map[string]string +} + +func NewRequestHeaderFactory(r io.Reader) (*RequestHeaderFactory, error) { + br := bufio.NewReader(r) + header := &RequestHeaderFactory{ + headers: make(map[string]string), + } + + startLine, err := br.ReadString('\n') + if err != nil { + return nil, err + } + startLine = strings.TrimRight(startLine, "\r\n") + header.startLine = []byte(startLine) + + parts := strings.Split(startLine, " ") + if len(parts) < 3 { + return nil, fmt.Errorf("invalid request line") + } + + header.Method = parts[0] + header.Path = parts[1] + header.Version = parts[2] + + for { + line, err := br.ReadString('\n') + if err != nil { + return nil, err + } + line = strings.TrimRight(line, "\r\n") + + if line == "" { + break + } + + kv := strings.SplitN(line, ":", 2) + if len(kv) != 2 { + continue + } + header.headers[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) + } + + return header, nil +} + +func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory { + header := &ResponseHeaderFactory{ + startLine: nil, + headers: make(map[string]string), + } + lines := bytes.Split(startLine, []byte("\r\n")) + if len(lines) == 0 { + return header + } + header.startLine = lines[0] + for _, h := range lines[1:] { + if len(h) == 0 { + continue + } + + parts := bytes.SplitN(h, []byte(":"), 2) + if len(parts) < 2 { + continue + } + + key := parts[0] + val := bytes.TrimSpace(parts[1]) + header.headers[string(key)] = string(val) + } + return header +} + +func (resp *ResponseHeaderFactory) Get(key string) string { + return resp.headers[key] +} + +func (resp *ResponseHeaderFactory) Set(key string, value string) { + resp.headers[key] = value +} + +func (resp *ResponseHeaderFactory) Remove(key string) { + delete(resp.headers, key) +} + +func (resp *ResponseHeaderFactory) Finalize() []byte { + var buf bytes.Buffer + + buf.Write(resp.startLine) + buf.WriteString("\r\n") + + for key, val := range resp.headers { + buf.WriteString(key) + buf.WriteString(": ") + buf.WriteString(val) + buf.WriteString("\r\n") + } + + buf.WriteString("\r\n") + return buf.Bytes() +} + +func (req *RequestHeaderFactory) Get(key string) string { + val, ok := req.headers[key] + if !ok { + return "" + } + return val +} + +func (req *RequestHeaderFactory) Set(key string, value string) { + req.headers[key] = value +} + +func (req *RequestHeaderFactory) Remove(key string) { + delete(req.headers, key) +} + +func (req *RequestHeaderFactory) Finalize() []byte { + var buf bytes.Buffer + + buf.Write(req.startLine) + buf.WriteString("\r\n") + + req.headers["X-HF"] = "modified" + + for key, val := range req.headers { + buf.WriteString(key) + buf.WriteString(": ") + buf.WriteString(val) + buf.WriteString("\r\n") + } + + buf.WriteString("\r\n") + return buf.Bytes() +} diff --git a/server/http.go b/server/http.go index 49e23dd..13f341c 100644 --- a/server/http.go +++ b/server/http.go @@ -5,16 +5,26 @@ import ( "bytes" "errors" "fmt" + "io" "log" "net" "net/http" + "regexp" + "strconv" "strings" + "time" "tunnel_pls/session" "tunnel_pls/utils" "github.com/gorilla/websocket" + "golang.org/x/crypto/ssh" ) +var BAD_GATEWAY_RESPONSE = []byte("HTTP/1.1 502 Bad Gateway\r\n" + + "Content-Length: 11\r\n" + + "Content-Type: text/plain\r\n\r\n" + + "Bad Gateway") + var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -70,6 +80,172 @@ func (w *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return w.conn, rw, nil } +type CustomWriter struct { + RemoteAddr net.Addr + writer io.Writer + reader io.Reader + headerBuf []byte + buf []byte + Requests []*RequestContext + interaction *session.Interaction +} + +type RequestContext struct { + Host string + Path string + Method string + Chunked bool + Tail []byte + ContentSize int + Written int +} + +func (cw *CustomWriter) Read(p []byte) (int, error) { + read, err := cw.reader.Read(p) + test := bytes.NewReader(p) + reqhf, _ := NewRequestHeaderFactory(test) + if reqhf != nil { + cw.Requests = append(cw.Requests, &RequestContext{ + Host: reqhf.Get("Host"), + Path: reqhf.Path, + Method: reqhf.Method, + Chunked: false, + Tail: make([]byte, 5), + ContentSize: 0, + Written: 0, + }) + } + return read, err +} + +func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter { + return &CustomWriter{ + RemoteAddr: remoteAddr, + writer: writer, + reader: reader, + buf: make([]byte, 0, 4096), + interaction: nil, + } +} + +var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} // HTTP HEADER DELIMITER `\r\n\r\n` +var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) +var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) + +func isHTTPHeader(buf []byte) bool { + lines := bytes.Split(buf, []byte("\r\n")) + if len(lines) < 1 { + return false + } + startLine := string(lines[0]) + if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { + return false + } + + for _, line := range lines[1:] { + if len(line) == 0 { + break + } + if !bytes.Contains(line, []byte(":")) { + return false + } + } + return true +} + +func (cw *CustomWriter) Write(p []byte) (int, error) { + if len(p) == len(BAD_GATEWAY_RESPONSE) && bytes.Equal(p, BAD_GATEWAY_RESPONSE) { + return cw.writer.Write(p) + } + + cw.buf = append(cw.buf, p...) + timestamp := time.Now().UTC().Format(time.RFC3339) + // TODO: implement middleware buat cache system dll + if idx := bytes.Index(cw.buf, DELIMITER); idx != -1 { + header := cw.buf[:idx+len(DELIMITER)] + body := cw.buf[idx+len(DELIMITER):] + + if isHTTPHeader(header) { + resphf := NewResponseHeaderFactory(header) + resphf.Set("Server", "Tunnel Please") + + if resphf.Get("Transfer-Encoding") == "chunked" { + cw.Requests[0].Chunked = true + } + if resphf.Get("Content-Length") != "" { + bodySize, err := strconv.Atoi(resphf.Get("Content-Length")) + if err != nil { + log.Printf("Error parsing Content-Length: %v", err) + cw.Requests[0].ContentSize = -1 + } else { + cw.Requests[0].ContentSize = bodySize + } + } else { + cw.Requests[0].ContentSize = -1 + } + + header = resphf.Finalize() + _, err := cw.writer.Write(header) + if err != nil { + return 0, err + } + + if len(body) > 0 { + _, err := cw.writer.Write(body) + if err != nil { + return 0, err + } + } + req := cw.Requests[0] + req.Written += len(body) + + if req.Chunked { + req.Tail = append(req.Tail, p[len(p)-5:]...) + if bytes.Equal(req.Tail, []byte("0\r\n\r\n")) { + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) + } + } else if req.ContentSize != -1 { + if req.Written >= req.ContentSize { + cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) + } + } else { + cw.Requests = cw.Requests[1:] + } + cw.buf = nil + return len(p), nil + } + } + cw.buf = nil + n, err := cw.writer.Write(p) + if err != nil { + return n, err + } + + req := cw.Requests[0] + req.Written += len(p) + if req.Chunked { + req.Tail = append(req.Tail, p[len(p)-5:]...) + if bytes.Equal(req.Tail, []byte("0\r\n\r\n")) { + cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) + } + } else if req.ContentSize != -1 { + if req.Written >= req.ContentSize { + cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) + } + } else { + cw.Requests = cw.Requests[1:] + } + + return n, nil +} + +func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) { + cw.interaction = interaction +} + var redirectTLS = false var allowedCors = make(map[string]bool) var isAllowedAllCors = false @@ -123,14 +299,30 @@ func NewHTTPServer() error { } func Handler(conn net.Conn) { - reader := bufio.NewReader(conn) - headers, err := peekUntilHeaders(reader, 8192) + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Error closing connection: %v", err) + return + } + return + }() + + dstReader := bufio.NewReader(conn) + reqhf, err := NewRequestHeaderFactory(dstReader) if err != nil { - log.Println("Failed to peek headers:", err) return } + cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - host := strings.Split(parseHostFromHeader(headers), ".") + // Initial Requests + cw.Requests = append(cw.Requests, &RequestContext{ + Host: reqhf.Get("Host"), + Path: reqhf.Path, + Method: reqhf.Method, + Chunked: false, + }) + host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) if err != nil { @@ -166,7 +358,7 @@ func Handler(conn net.Conn) { } if slug == "ping" { - req, err := http.ReadRequest(reader) + req, err := http.ReadRequest(dstReader) if err != nil { log.Println("failed to parse HTTP request:", err) return @@ -218,39 +410,61 @@ func Handler(conn net.Conn) { return } - sshSession.HandleForwardedConnection(session.UserConnection{ - Reader: reader, - Writer: conn, - }, sshSession.Conn) + forwardRequest(cw, reqhf, sshSession) return } -func peekUntilHeaders(reader *bufio.Reader, maxBytes int) ([]byte, error) { - var buf []byte - for { - n := len(buf) + 1 - if n > maxBytes { - return buf, nil - } - - peek, err := reader.Peek(n) +func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { + cw.AddInteraction(sshSession.Interaction) + originHost, originPort := ParseAddr(cw.RemoteAddr.String()) + payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort()) + channel, reqs, err := sshSession.Conn.OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + sendBadGatewayResponse(cw) + return + } + defer func(channel ssh.Channel) { + err := channel.Close() if err != nil { - return nil, err + if errors.Is(err, io.EOF) { + sendBadGatewayResponse(cw) + return + } + log.Println("Failed to close connection:", err) + return } - buf = peek + }(channel) - if bytes.Contains(buf, []byte("\r\n\r\n")) { - return buf, nil + go func() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in request handler: %v", r) + } + }() + for req := range reqs { + err := req.Reply(false, nil) + if err != nil { + log.Printf("Failed to reply to request: %v", err) + return + } } + }() + + _, err = channel.Write(initialRequest.Finalize()) + if err != nil { + log.Printf("Failed to write forwarded-tcpip:", err) + return } + + sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) + return } -func parseHostFromHeader(data []byte) string { - lines := strings.Split(string(data), "\r\n") - for _, line := range lines { - if strings.HasPrefix(strings.ToLower(line), "host:") { - return strings.TrimSpace(strings.TrimPrefix(line, "Host:")) - } +func sendBadGatewayResponse(writer io.Writer) { + _, err := writer.Write(BAD_GATEWAY_RESPONSE) + if err != nil { + log.Printf("failed to write Bad Gateway response: %v", err) + return } - return "" } diff --git a/server/https.go b/server/https.go index 649287d..fe3221f 100644 --- a/server/https.go +++ b/server/https.go @@ -45,14 +45,31 @@ func NewHTTPSServer() error { } func HandlerTLS(conn net.Conn) { - reader := bufio.NewReader(conn) - headers, err := peekUntilHeaders(reader, 8192) + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Error closing connection: %v", err) + return + } + return + }() + + dstReader := bufio.NewReader(conn) + reqhf, err := NewRequestHeaderFactory(dstReader) if err != nil { - log.Println("Failed to peek headers:", err) return } + cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - host := strings.Split(parseHostFromHeader(headers), ".") + // Initial Requests + cw.Requests = append(cw.Requests, &RequestContext{ + Host: reqhf.Get("Host"), + Path: reqhf.Path, + Method: reqhf.Method, + Chunked: false, + }) + + host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) if err != nil { @@ -70,7 +87,7 @@ func HandlerTLS(conn net.Conn) { slug := host[0] if slug == "ping" { - req, err := http.ReadRequest(reader) + req, err := http.ReadRequest(dstReader) if err != nil { log.Println("failed to parse HTTP request:", err) return @@ -121,10 +138,6 @@ func HandlerTLS(conn net.Conn) { } return } - - sshSession.HandleForwardedConnection(session.UserConnection{ - Reader: reader, - Writer: conn, - }, sshSession.Conn) + forwardRequest(cw, reqhf, sshSession) return } diff --git a/server/server.go b/server/server.go index 3f5e739..0e6bdb6 100644 --- a/server/server.go +++ b/server/server.go @@ -1,12 +1,16 @@ package server import ( + "bytes" + "encoding/binary" "fmt" - "golang.org/x/crypto/ssh" "log" "net" "net/http" + "strconv" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) type Server struct { @@ -54,3 +58,41 @@ func (s *Server) Start() { go s.handleConnection(conn) } } + +func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { + var buf bytes.Buffer + + writeSSHString(&buf, "localhost") + err := binary.Write(&buf, binary.BigEndian, uint32(port)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + writeSSHString(&buf, host) + err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + + return buf.Bytes() +} + +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return + } + buffer.WriteString(str) +} + +func ParseAddr(addr string) (string, uint32) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint32(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint32(port) +} diff --git a/session/handler.go b/session/handler.go index b30c7ff..53c753e 100644 --- a/session/handler.go +++ b/session/handler.go @@ -1,9 +1,7 @@ package session import ( - "bufio" "bytes" - "context" "encoding/binary" "errors" "fmt" @@ -56,8 +54,8 @@ func unregisterClient(slug string) { } func (s *SSHSession) Close() error { - if s.forwarder.Listener != nil { - err := s.forwarder.Listener.Close() + if s.Forwarder.Listener != nil { + err := s.Forwarder.Listener.Close() if err != nil && !errors.Is(err, net.ErrClosed) { return err } @@ -77,13 +75,13 @@ func (s *SSHSession) Close() error { } } - slug := s.forwarder.getSlug() + slug := s.Forwarder.getSlug() if slug != "" { unregisterClient(slug) } - if s.forwarder.TunnelType == TCP && s.forwarder.Listener != nil { - err := portUtil.Manager.SetPortStatus(s.forwarder.ForwardedPort, false) + if s.Forwarder.TunnelType == TCP && s.Forwarder.Listener != nil { + err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) if err != nil { return err } @@ -138,7 +136,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { var rawPortToBind uint32 if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { log.Println("Failed to read port from payload:", err) - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -152,7 +150,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { } if rawPortToBind > 65535 { - s.interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -168,7 +166,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { portToBind := uint16(rawPortToBind) if isBlockedPort(portToBind) { - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -181,9 +179,9 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { return } - s.interaction.SendMessage("\033[H\033[2J") - s.lifecycle.Status = RUNNING - go s.interaction.HandleUserInput() + s.Interaction.SendMessage("\033[H\033[2J") + s.Lifecycle.Status = RUNNING + go s.Interaction.HandleUserInput() if portToBind == 80 || portToBind == 443 { s.handleHTTPForward(req, portToBind) @@ -193,7 +191,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { unassign, success := portUtil.Manager.GetUnassignedPort() portToBind = unassign if !success { - s.interaction.SendMessage(fmt.Sprintf("No available port\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("No available port\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -206,7 +204,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { return } } else if isUse, isExist := portUtil.Manager.GetPortStatus(portToBind); isExist && isUse { - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -245,8 +243,8 @@ func isBlockedPort(port uint16) bool { } func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { - s.forwarder.TunnelType = HTTP - s.forwarder.ForwardedPort = portToBind + s.Forwarder.TunnelType = HTTP + s.Forwarder.ForwardedPort = portToBind slug := generateUniqueSlug() if slug == "" { @@ -258,7 +256,7 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { return } - s.forwarder.setSlug(slug) + s.Forwarder.setSlug(slug) registerClient(slug, s) buf := new(bytes.Buffer) @@ -275,8 +273,8 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { protocol = "https" } - s.interaction.ShowWelcomeMessage() - s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) + s.Interaction.ShowWelcomeMessage() + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) @@ -285,12 +283,12 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { } func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - s.forwarder.TunnelType = TCP + s.Forwarder.TunnelType = TCP log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -302,10 +300,10 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind } return } - s.forwarder.Listener = listener - s.forwarder.ForwardedPort = portToBind - s.interaction.ShowWelcomeMessage() - s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.forwarder.TunnelType, utils.Getenv("domain"), s.forwarder.ForwardedPort)) + s.Forwarder.Listener = listener + s.Forwarder.ForwardedPort = portToBind + s.Interaction.ShowWelcomeMessage() + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.TunnelType, utils.Getenv("domain"), s.Forwarder.ForwardedPort)) go s.acceptTCPConnections() @@ -325,7 +323,7 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind func (s *SSHSession) acceptTCPConnections() { for { - conn, err := s.forwarder.Listener.Accept() + conn, err := s.Forwarder.Listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -333,11 +331,35 @@ func (s *SSHSession) acceptTCPConnections() { log.Printf("Error accepting connection: %v", err) continue } + originHost, originPort := ParseAddr(conn.RemoteAddr().String()) + payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort()) + channel, reqs, err := s.Conn.OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + return + } + defer func(channel ssh.Channel) { + err := channel.Close() + if err != nil { + log.Println("Failed to close connection:", err) + } + }(channel) - go s.HandleForwardedConnection(UserConnection{ - Reader: nil, - Writer: conn, - }, s.Conn) + go func() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in request handler: %v", r) + } + }() + for req := range reqs { + err := req.Reply(false, nil) + if err != nil { + log.Printf("Failed to reply to request: %v", err) + return + } + } + }() + go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr()) } } @@ -369,15 +391,15 @@ func (s *SSHSession) waitForRunningStatus() { for { select { case <-ticker.C: - s.interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) + s.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) i = (i + 1) % len(frames) - if s.lifecycle.Status == RUNNING { - s.interaction.SendMessage("\r\033[K") + if s.Lifecycle.Status == RUNNING { + s.Interaction.SendMessage("\r\033[K") return } case <-timeout: - s.interaction.SendMessage("\r\033[K") - s.interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") + s.Interaction.SendMessage("\r\033[K") + s.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") err := s.Close() if err != nil { log.Printf("failed to close session: %v", err) @@ -425,137 +447,40 @@ func waitForKeyPress(connection ssh.Channel) { } } -func (s *SSHSession) HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn) { - defer func(Writer net.Conn) { - err := Writer.Close() - if err != nil { - log.Println("Failed to close connection:", err) +func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { + defer func(src ssh.Channel) { + err := src.Close() + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing connection: %v", err) } - }(conn.Writer) - - log.Printf("Handling new forwarded connection from %s", conn.Writer.RemoteAddr()) - host, originPort := ParseAddr(conn.Writer.RemoteAddr().String()) - - timestamp := time.Now().Format("02/Jan/2006 15:04:05") - - payload := createForwardedTCPIPPayload(host, uint16(originPort), s.forwarder.ForwardedPort) - channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - sendBadGatewayResponse(conn.Writer) - return - } - defer func(channel ssh.Channel) { - err := channel.Close() - if err != nil { - log.Println("Failed to close connection:", err) - } - }(channel) + }(src) + log.Printf("Handling new forwarded connection from %s", remoteAddr) go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in request handler: %v", r) - } - }() - for req := range reqs { - err := req.Reply(false, nil) - if err != nil { - log.Printf("Failed to reply to request: %v", err) - return - } - } - }() - - if conn.Reader == nil { - conn.Reader = bufio.NewReader(conn.Writer) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in reader copy: %v", r) - } - cancel() - }() - _, err := io.Copy(channel, conn.Reader) + _, err := io.Copy(src, dst) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { log.Printf("Error copying from conn.Reader to channel: %v", err) } - cancel() }() - reader := bufio.NewReader(channel) - - peekChan := make(chan error, 1) - go func() { - _, err := reader.Peek(1) - peekChan <- err - }() - - select { - case err := <-peekChan: - if err == io.EOF { - s.interaction.SendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - } - if err != nil { - log.Printf("Error peeking channel data: %v", err) - s.interaction.SendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - } - case <-time.After(5 * time.Second): - log.Printf("Timeout waiting for channel data from %s", conn.Writer.RemoteAddr()) - s.interaction.SendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - case <-ctx.Done(): - return - } - - s.interaction.SendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType, timestamp)) - - _, err = io.Copy(conn.Writer, reader) + _, err := io.Copy(dst, src) if err != nil && !errors.Is(err, io.EOF) { log.Printf("Error copying from channel to conn.Writer: %v", err) } + return } -func sendBadGatewayResponse(writer io.Writer) { - response := "HTTP/1.1 502 Bad Gateway\r\n" + - "Content-Length: 11\r\n" + - "Content-Type: text/plain\r\n\r\n" + - "Bad Gateway" - _, err := io.Copy(writer, bytes.NewReader([]byte(response))) - if err != nil { - log.Printf("failed to write Bad Gateway response: %v", err) - return +func readSSHString(reader *bytes.Reader) (string, error) { + var length uint32 + if err := binary.Read(reader, binary.BigEndian, &length); err != nil { + return "", err } -} - -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return + strBytes := make([]byte, length) + if _, err := reader.Read(strBytes); err != nil { + return "", err } - buffer.WriteString(str) -} - -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) + return string(strBytes), nil } func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { @@ -577,14 +502,21 @@ func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { return buf.Bytes() } -func readSSHString(reader *bytes.Reader) (string, error) { - var length uint32 - if err := binary.Read(reader, binary.BigEndian, &length); err != nil { - return "", err +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return } - strBytes := make([]byte, length) - if _, err := reader.Read(strBytes); err != nil { - return "", err - } - return string(strBytes), nil + buffer.WriteString(str) +} + +func ParseAddr(addr string) (string, uint32) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint32(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint32(port) } diff --git a/session/interaction.go b/session/interaction.go index 3ef056a..b22c87b 100644 --- a/session/interaction.go +++ b/session/interaction.go @@ -371,7 +371,7 @@ func updateClientSlug(oldSlug, newSlug string) bool { } delete(Clients, oldSlug) - client.forwarder.setSlug(newSlug) + client.Forwarder.setSlug(newSlug) Clients[newSlug] = client return true } diff --git a/session/session.go b/session/session.go index 6d21093..fc8c126 100644 --- a/session/session.go +++ b/session/session.go @@ -49,7 +49,6 @@ type ForwardingController interface { HandleHTTPForward(req *ssh.Request, port uint16) HandleTCPForward(req *ssh.Request, addr string, port uint16) AcceptTCPConnections() - HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn) } type Session interface { @@ -98,9 +97,9 @@ type Interaction struct { forwarder ForwarderInfo } type SSHSession struct { - lifecycle *Lifecycle - interaction *Interaction - forwarder *Forwarder + Lifecycle *Lifecycle + Interaction *Interaction + Forwarder *Forwarder Conn *ssh.ServerConn channel ssh.Channel @@ -111,10 +110,10 @@ type SSHSession struct { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { session := SSHSession{ - lifecycle: &Lifecycle{ + Lifecycle: &Lifecycle{ Status: INITIALIZING, }, - interaction: &Interaction{ + Interaction: &Interaction{ CommandBuffer: new(bytes.Buffer), EditMode: false, EditSlug: "", @@ -124,7 +123,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan session: nil, forwarder: nil, }, - forwarder: &Forwarder{ + Forwarder: &Forwarder{ Listener: nil, TunnelType: "", ForwardedPort: 0, @@ -136,12 +135,12 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan slug: "", } - session.forwarder.getSlug = session.GetSlug - session.forwarder.setSlug = session.SetSlug - session.interaction.getSlug = session.GetSlug - session.interaction.setSlug = session.SetSlug - session.interaction.session = &session - session.interaction.forwarder = session.forwarder + session.Forwarder.getSlug = session.GetSlug + session.Forwarder.setSlug = session.SetSlug + session.Interaction.getSlug = session.GetSlug + session.Interaction.setSlug = session.SetSlug + session.Interaction.session = &session + session.Interaction.forwarder = session.Forwarder go func() { go session.waitForRunningStatus() @@ -150,8 +149,8 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ch, reqs, _ := channel.Accept() if session.channel == nil { session.channel = ch - session.interaction.channel = ch - session.lifecycle.Status = SETUP + session.Interaction.channel = ch + session.Lifecycle.Status = SETUP go session.HandleGlobalRequest(forwardingReq) } go session.HandleGlobalRequest(reqs)