From d7d6e24a42897ef5be401ecba85e9f37fb6723ee Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 2 Dec 2025 17:14:17 +0700 Subject: [PATCH 1/8] feat: add header factory --- server/header.go | 163 ++++++++++++++++++++++++ server/http.go | 272 ++++++++++++++++++++++++++++++++++++----- server/https.go | 33 +++-- server/server.go | 44 ++++++- session/handler.go | 252 ++++++++++++++------------------------ session/interaction.go | 2 +- session/session.go | 29 +++-- 7 files changed, 579 insertions(+), 216 deletions(-) create mode 100644 server/header.go diff --git a/server/header.go b/server/header.go new file mode 100644 index 0000000..cb7602e --- /dev/null +++ b/server/header.go @@ -0,0 +1,163 @@ +package server + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" +) + +type HeaderManager interface { + Get(key string) []byte + Set(key string, value []byte) + Remove(key string) + Finalize() []byte +} + +type ResponseHeaderFactory struct { + startLine []byte + headers map[string]string +} + +type RequestHeaderFactory struct { + Method string + Path string + Version string + startLine []byte + headers map[string]string +} + +func NewRequestHeaderFactory(r io.Reader) (*RequestHeaderFactory, error) { + br := bufio.NewReader(r) + header := &RequestHeaderFactory{ + headers: make(map[string]string), + } + + startLine, err := br.ReadString('\n') + if err != nil { + return nil, err + } + startLine = strings.TrimRight(startLine, "\r\n") + header.startLine = []byte(startLine) + + parts := strings.Split(startLine, " ") + if len(parts) < 3 { + return nil, fmt.Errorf("invalid request line") + } + + header.Method = parts[0] + header.Path = parts[1] + header.Version = parts[2] + + for { + line, err := br.ReadString('\n') + if err != nil { + return nil, err + } + line = strings.TrimRight(line, "\r\n") + + if line == "" { + break + } + + kv := strings.SplitN(line, ":", 2) + if len(kv) != 2 { + continue + } + header.headers[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) + } + + return header, nil +} + +func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory { + header := &ResponseHeaderFactory{ + startLine: nil, + headers: make(map[string]string), + } + lines := bytes.Split(startLine, []byte("\r\n")) + if len(lines) == 0 { + return header + } + header.startLine = lines[0] + for _, h := range lines[1:] { + if len(h) == 0 { + continue + } + + parts := bytes.SplitN(h, []byte(":"), 2) + if len(parts) < 2 { + continue + } + + key := parts[0] + val := bytes.TrimSpace(parts[1]) + header.headers[string(key)] = string(val) + } + return header +} + +func (resp *ResponseHeaderFactory) Get(key string) string { + return resp.headers[key] +} + +func (resp *ResponseHeaderFactory) Set(key string, value string) { + resp.headers[key] = value +} + +func (resp *ResponseHeaderFactory) Remove(key string) { + delete(resp.headers, key) +} + +func (resp *ResponseHeaderFactory) Finalize() []byte { + var buf bytes.Buffer + + buf.Write(resp.startLine) + buf.WriteString("\r\n") + + for key, val := range resp.headers { + buf.WriteString(key) + buf.WriteString(": ") + buf.WriteString(val) + buf.WriteString("\r\n") + } + + buf.WriteString("\r\n") + return buf.Bytes() +} + +func (req *RequestHeaderFactory) Get(key string) string { + val, ok := req.headers[key] + if !ok { + return "" + } + return val +} + +func (req *RequestHeaderFactory) Set(key string, value string) { + req.headers[key] = value +} + +func (req *RequestHeaderFactory) Remove(key string) { + delete(req.headers, key) +} + +func (req *RequestHeaderFactory) Finalize() []byte { + var buf bytes.Buffer + + buf.Write(req.startLine) + buf.WriteString("\r\n") + + req.headers["X-HF"] = "modified" + + for key, val := range req.headers { + buf.WriteString(key) + buf.WriteString(": ") + buf.WriteString(val) + buf.WriteString("\r\n") + } + + buf.WriteString("\r\n") + return buf.Bytes() +} diff --git a/server/http.go b/server/http.go index 49e23dd..13f341c 100644 --- a/server/http.go +++ b/server/http.go @@ -5,16 +5,26 @@ import ( "bytes" "errors" "fmt" + "io" "log" "net" "net/http" + "regexp" + "strconv" "strings" + "time" "tunnel_pls/session" "tunnel_pls/utils" "github.com/gorilla/websocket" + "golang.org/x/crypto/ssh" ) +var BAD_GATEWAY_RESPONSE = []byte("HTTP/1.1 502 Bad Gateway\r\n" + + "Content-Length: 11\r\n" + + "Content-Type: text/plain\r\n\r\n" + + "Bad Gateway") + var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -70,6 +80,172 @@ func (w *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return w.conn, rw, nil } +type CustomWriter struct { + RemoteAddr net.Addr + writer io.Writer + reader io.Reader + headerBuf []byte + buf []byte + Requests []*RequestContext + interaction *session.Interaction +} + +type RequestContext struct { + Host string + Path string + Method string + Chunked bool + Tail []byte + ContentSize int + Written int +} + +func (cw *CustomWriter) Read(p []byte) (int, error) { + read, err := cw.reader.Read(p) + test := bytes.NewReader(p) + reqhf, _ := NewRequestHeaderFactory(test) + if reqhf != nil { + cw.Requests = append(cw.Requests, &RequestContext{ + Host: reqhf.Get("Host"), + Path: reqhf.Path, + Method: reqhf.Method, + Chunked: false, + Tail: make([]byte, 5), + ContentSize: 0, + Written: 0, + }) + } + return read, err +} + +func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter { + return &CustomWriter{ + RemoteAddr: remoteAddr, + writer: writer, + reader: reader, + buf: make([]byte, 0, 4096), + interaction: nil, + } +} + +var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} // HTTP HEADER DELIMITER `\r\n\r\n` +var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) +var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) + +func isHTTPHeader(buf []byte) bool { + lines := bytes.Split(buf, []byte("\r\n")) + if len(lines) < 1 { + return false + } + startLine := string(lines[0]) + if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { + return false + } + + for _, line := range lines[1:] { + if len(line) == 0 { + break + } + if !bytes.Contains(line, []byte(":")) { + return false + } + } + return true +} + +func (cw *CustomWriter) Write(p []byte) (int, error) { + if len(p) == len(BAD_GATEWAY_RESPONSE) && bytes.Equal(p, BAD_GATEWAY_RESPONSE) { + return cw.writer.Write(p) + } + + cw.buf = append(cw.buf, p...) + timestamp := time.Now().UTC().Format(time.RFC3339) + // TODO: implement middleware buat cache system dll + if idx := bytes.Index(cw.buf, DELIMITER); idx != -1 { + header := cw.buf[:idx+len(DELIMITER)] + body := cw.buf[idx+len(DELIMITER):] + + if isHTTPHeader(header) { + resphf := NewResponseHeaderFactory(header) + resphf.Set("Server", "Tunnel Please") + + if resphf.Get("Transfer-Encoding") == "chunked" { + cw.Requests[0].Chunked = true + } + if resphf.Get("Content-Length") != "" { + bodySize, err := strconv.Atoi(resphf.Get("Content-Length")) + if err != nil { + log.Printf("Error parsing Content-Length: %v", err) + cw.Requests[0].ContentSize = -1 + } else { + cw.Requests[0].ContentSize = bodySize + } + } else { + cw.Requests[0].ContentSize = -1 + } + + header = resphf.Finalize() + _, err := cw.writer.Write(header) + if err != nil { + return 0, err + } + + if len(body) > 0 { + _, err := cw.writer.Write(body) + if err != nil { + return 0, err + } + } + req := cw.Requests[0] + req.Written += len(body) + + if req.Chunked { + req.Tail = append(req.Tail, p[len(p)-5:]...) + if bytes.Equal(req.Tail, []byte("0\r\n\r\n")) { + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) + } + } else if req.ContentSize != -1 { + if req.Written >= req.ContentSize { + cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) + } + } else { + cw.Requests = cw.Requests[1:] + } + cw.buf = nil + return len(p), nil + } + } + cw.buf = nil + n, err := cw.writer.Write(p) + if err != nil { + return n, err + } + + req := cw.Requests[0] + req.Written += len(p) + if req.Chunked { + req.Tail = append(req.Tail, p[len(p)-5:]...) + if bytes.Equal(req.Tail, []byte("0\r\n\r\n")) { + cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) + } + } else if req.ContentSize != -1 { + if req.Written >= req.ContentSize { + cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) + } + } else { + cw.Requests = cw.Requests[1:] + } + + return n, nil +} + +func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) { + cw.interaction = interaction +} + var redirectTLS = false var allowedCors = make(map[string]bool) var isAllowedAllCors = false @@ -123,14 +299,30 @@ func NewHTTPServer() error { } func Handler(conn net.Conn) { - reader := bufio.NewReader(conn) - headers, err := peekUntilHeaders(reader, 8192) + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Error closing connection: %v", err) + return + } + return + }() + + dstReader := bufio.NewReader(conn) + reqhf, err := NewRequestHeaderFactory(dstReader) if err != nil { - log.Println("Failed to peek headers:", err) return } + cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - host := strings.Split(parseHostFromHeader(headers), ".") + // Initial Requests + cw.Requests = append(cw.Requests, &RequestContext{ + Host: reqhf.Get("Host"), + Path: reqhf.Path, + Method: reqhf.Method, + Chunked: false, + }) + host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) if err != nil { @@ -166,7 +358,7 @@ func Handler(conn net.Conn) { } if slug == "ping" { - req, err := http.ReadRequest(reader) + req, err := http.ReadRequest(dstReader) if err != nil { log.Println("failed to parse HTTP request:", err) return @@ -218,39 +410,61 @@ func Handler(conn net.Conn) { return } - sshSession.HandleForwardedConnection(session.UserConnection{ - Reader: reader, - Writer: conn, - }, sshSession.Conn) + forwardRequest(cw, reqhf, sshSession) return } -func peekUntilHeaders(reader *bufio.Reader, maxBytes int) ([]byte, error) { - var buf []byte - for { - n := len(buf) + 1 - if n > maxBytes { - return buf, nil - } - - peek, err := reader.Peek(n) +func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { + cw.AddInteraction(sshSession.Interaction) + originHost, originPort := ParseAddr(cw.RemoteAddr.String()) + payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort()) + channel, reqs, err := sshSession.Conn.OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + sendBadGatewayResponse(cw) + return + } + defer func(channel ssh.Channel) { + err := channel.Close() if err != nil { - return nil, err + if errors.Is(err, io.EOF) { + sendBadGatewayResponse(cw) + return + } + log.Println("Failed to close connection:", err) + return } - buf = peek + }(channel) - if bytes.Contains(buf, []byte("\r\n\r\n")) { - return buf, nil + go func() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in request handler: %v", r) + } + }() + for req := range reqs { + err := req.Reply(false, nil) + if err != nil { + log.Printf("Failed to reply to request: %v", err) + return + } } + }() + + _, err = channel.Write(initialRequest.Finalize()) + if err != nil { + log.Printf("Failed to write forwarded-tcpip:", err) + return } + + sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) + return } -func parseHostFromHeader(data []byte) string { - lines := strings.Split(string(data), "\r\n") - for _, line := range lines { - if strings.HasPrefix(strings.ToLower(line), "host:") { - return strings.TrimSpace(strings.TrimPrefix(line, "Host:")) - } +func sendBadGatewayResponse(writer io.Writer) { + _, err := writer.Write(BAD_GATEWAY_RESPONSE) + if err != nil { + log.Printf("failed to write Bad Gateway response: %v", err) + return } - return "" } diff --git a/server/https.go b/server/https.go index 649287d..fe3221f 100644 --- a/server/https.go +++ b/server/https.go @@ -45,14 +45,31 @@ func NewHTTPSServer() error { } func HandlerTLS(conn net.Conn) { - reader := bufio.NewReader(conn) - headers, err := peekUntilHeaders(reader, 8192) + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Error closing connection: %v", err) + return + } + return + }() + + dstReader := bufio.NewReader(conn) + reqhf, err := NewRequestHeaderFactory(dstReader) if err != nil { - log.Println("Failed to peek headers:", err) return } + cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - host := strings.Split(parseHostFromHeader(headers), ".") + // Initial Requests + cw.Requests = append(cw.Requests, &RequestContext{ + Host: reqhf.Get("Host"), + Path: reqhf.Path, + Method: reqhf.Method, + Chunked: false, + }) + + host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) if err != nil { @@ -70,7 +87,7 @@ func HandlerTLS(conn net.Conn) { slug := host[0] if slug == "ping" { - req, err := http.ReadRequest(reader) + req, err := http.ReadRequest(dstReader) if err != nil { log.Println("failed to parse HTTP request:", err) return @@ -121,10 +138,6 @@ func HandlerTLS(conn net.Conn) { } return } - - sshSession.HandleForwardedConnection(session.UserConnection{ - Reader: reader, - Writer: conn, - }, sshSession.Conn) + forwardRequest(cw, reqhf, sshSession) return } diff --git a/server/server.go b/server/server.go index 3f5e739..0e6bdb6 100644 --- a/server/server.go +++ b/server/server.go @@ -1,12 +1,16 @@ package server import ( + "bytes" + "encoding/binary" "fmt" - "golang.org/x/crypto/ssh" "log" "net" "net/http" + "strconv" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) type Server struct { @@ -54,3 +58,41 @@ func (s *Server) Start() { go s.handleConnection(conn) } } + +func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { + var buf bytes.Buffer + + writeSSHString(&buf, "localhost") + err := binary.Write(&buf, binary.BigEndian, uint32(port)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + writeSSHString(&buf, host) + err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + + return buf.Bytes() +} + +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return + } + buffer.WriteString(str) +} + +func ParseAddr(addr string) (string, uint32) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint32(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint32(port) +} diff --git a/session/handler.go b/session/handler.go index b30c7ff..53c753e 100644 --- a/session/handler.go +++ b/session/handler.go @@ -1,9 +1,7 @@ package session import ( - "bufio" "bytes" - "context" "encoding/binary" "errors" "fmt" @@ -56,8 +54,8 @@ func unregisterClient(slug string) { } func (s *SSHSession) Close() error { - if s.forwarder.Listener != nil { - err := s.forwarder.Listener.Close() + if s.Forwarder.Listener != nil { + err := s.Forwarder.Listener.Close() if err != nil && !errors.Is(err, net.ErrClosed) { return err } @@ -77,13 +75,13 @@ func (s *SSHSession) Close() error { } } - slug := s.forwarder.getSlug() + slug := s.Forwarder.getSlug() if slug != "" { unregisterClient(slug) } - if s.forwarder.TunnelType == TCP && s.forwarder.Listener != nil { - err := portUtil.Manager.SetPortStatus(s.forwarder.ForwardedPort, false) + if s.Forwarder.TunnelType == TCP && s.Forwarder.Listener != nil { + err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) if err != nil { return err } @@ -138,7 +136,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { var rawPortToBind uint32 if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { log.Println("Failed to read port from payload:", err) - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -152,7 +150,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { } if rawPortToBind > 65535 { - s.interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -168,7 +166,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { portToBind := uint16(rawPortToBind) if isBlockedPort(portToBind) { - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -181,9 +179,9 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { return } - s.interaction.SendMessage("\033[H\033[2J") - s.lifecycle.Status = RUNNING - go s.interaction.HandleUserInput() + s.Interaction.SendMessage("\033[H\033[2J") + s.Lifecycle.Status = RUNNING + go s.Interaction.HandleUserInput() if portToBind == 80 || portToBind == 443 { s.handleHTTPForward(req, portToBind) @@ -193,7 +191,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { unassign, success := portUtil.Manager.GetUnassignedPort() portToBind = unassign if !success { - s.interaction.SendMessage(fmt.Sprintf("No available port\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("No available port\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -206,7 +204,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { return } } else if isUse, isExist := portUtil.Manager.GetPortStatus(portToBind); isExist && isUse { - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -245,8 +243,8 @@ func isBlockedPort(port uint16) bool { } func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { - s.forwarder.TunnelType = HTTP - s.forwarder.ForwardedPort = portToBind + s.Forwarder.TunnelType = HTTP + s.Forwarder.ForwardedPort = portToBind slug := generateUniqueSlug() if slug == "" { @@ -258,7 +256,7 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { return } - s.forwarder.setSlug(slug) + s.Forwarder.setSlug(slug) registerClient(slug, s) buf := new(bytes.Buffer) @@ -275,8 +273,8 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { protocol = "https" } - s.interaction.ShowWelcomeMessage() - s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) + s.Interaction.ShowWelcomeMessage() + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) @@ -285,12 +283,12 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { } func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - s.forwarder.TunnelType = TCP + s.Forwarder.TunnelType = TCP log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -302,10 +300,10 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind } return } - s.forwarder.Listener = listener - s.forwarder.ForwardedPort = portToBind - s.interaction.ShowWelcomeMessage() - s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.forwarder.TunnelType, utils.Getenv("domain"), s.forwarder.ForwardedPort)) + s.Forwarder.Listener = listener + s.Forwarder.ForwardedPort = portToBind + s.Interaction.ShowWelcomeMessage() + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.TunnelType, utils.Getenv("domain"), s.Forwarder.ForwardedPort)) go s.acceptTCPConnections() @@ -325,7 +323,7 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind func (s *SSHSession) acceptTCPConnections() { for { - conn, err := s.forwarder.Listener.Accept() + conn, err := s.Forwarder.Listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -333,11 +331,35 @@ func (s *SSHSession) acceptTCPConnections() { log.Printf("Error accepting connection: %v", err) continue } + originHost, originPort := ParseAddr(conn.RemoteAddr().String()) + payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort()) + channel, reqs, err := s.Conn.OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + return + } + defer func(channel ssh.Channel) { + err := channel.Close() + if err != nil { + log.Println("Failed to close connection:", err) + } + }(channel) - go s.HandleForwardedConnection(UserConnection{ - Reader: nil, - Writer: conn, - }, s.Conn) + go func() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in request handler: %v", r) + } + }() + for req := range reqs { + err := req.Reply(false, nil) + if err != nil { + log.Printf("Failed to reply to request: %v", err) + return + } + } + }() + go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr()) } } @@ -369,15 +391,15 @@ func (s *SSHSession) waitForRunningStatus() { for { select { case <-ticker.C: - s.interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) + s.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) i = (i + 1) % len(frames) - if s.lifecycle.Status == RUNNING { - s.interaction.SendMessage("\r\033[K") + if s.Lifecycle.Status == RUNNING { + s.Interaction.SendMessage("\r\033[K") return } case <-timeout: - s.interaction.SendMessage("\r\033[K") - s.interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") + s.Interaction.SendMessage("\r\033[K") + s.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") err := s.Close() if err != nil { log.Printf("failed to close session: %v", err) @@ -425,137 +447,40 @@ func waitForKeyPress(connection ssh.Channel) { } } -func (s *SSHSession) HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn) { - defer func(Writer net.Conn) { - err := Writer.Close() - if err != nil { - log.Println("Failed to close connection:", err) +func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { + defer func(src ssh.Channel) { + err := src.Close() + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing connection: %v", err) } - }(conn.Writer) - - log.Printf("Handling new forwarded connection from %s", conn.Writer.RemoteAddr()) - host, originPort := ParseAddr(conn.Writer.RemoteAddr().String()) - - timestamp := time.Now().Format("02/Jan/2006 15:04:05") - - payload := createForwardedTCPIPPayload(host, uint16(originPort), s.forwarder.ForwardedPort) - channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - sendBadGatewayResponse(conn.Writer) - return - } - defer func(channel ssh.Channel) { - err := channel.Close() - if err != nil { - log.Println("Failed to close connection:", err) - } - }(channel) + }(src) + log.Printf("Handling new forwarded connection from %s", remoteAddr) go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in request handler: %v", r) - } - }() - for req := range reqs { - err := req.Reply(false, nil) - if err != nil { - log.Printf("Failed to reply to request: %v", err) - return - } - } - }() - - if conn.Reader == nil { - conn.Reader = bufio.NewReader(conn.Writer) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in reader copy: %v", r) - } - cancel() - }() - _, err := io.Copy(channel, conn.Reader) + _, err := io.Copy(src, dst) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { log.Printf("Error copying from conn.Reader to channel: %v", err) } - cancel() }() - reader := bufio.NewReader(channel) - - peekChan := make(chan error, 1) - go func() { - _, err := reader.Peek(1) - peekChan <- err - }() - - select { - case err := <-peekChan: - if err == io.EOF { - s.interaction.SendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - } - if err != nil { - log.Printf("Error peeking channel data: %v", err) - s.interaction.SendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - } - case <-time.After(5 * time.Second): - log.Printf("Timeout waiting for channel data from %s", conn.Writer.RemoteAddr()) - s.interaction.SendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - case <-ctx.Done(): - return - } - - s.interaction.SendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType, timestamp)) - - _, err = io.Copy(conn.Writer, reader) + _, err := io.Copy(dst, src) if err != nil && !errors.Is(err, io.EOF) { log.Printf("Error copying from channel to conn.Writer: %v", err) } + return } -func sendBadGatewayResponse(writer io.Writer) { - response := "HTTP/1.1 502 Bad Gateway\r\n" + - "Content-Length: 11\r\n" + - "Content-Type: text/plain\r\n\r\n" + - "Bad Gateway" - _, err := io.Copy(writer, bytes.NewReader([]byte(response))) - if err != nil { - log.Printf("failed to write Bad Gateway response: %v", err) - return +func readSSHString(reader *bytes.Reader) (string, error) { + var length uint32 + if err := binary.Read(reader, binary.BigEndian, &length); err != nil { + return "", err } -} - -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return + strBytes := make([]byte, length) + if _, err := reader.Read(strBytes); err != nil { + return "", err } - buffer.WriteString(str) -} - -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) + return string(strBytes), nil } func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { @@ -577,14 +502,21 @@ func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { return buf.Bytes() } -func readSSHString(reader *bytes.Reader) (string, error) { - var length uint32 - if err := binary.Read(reader, binary.BigEndian, &length); err != nil { - return "", err +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return } - strBytes := make([]byte, length) - if _, err := reader.Read(strBytes); err != nil { - return "", err - } - return string(strBytes), nil + buffer.WriteString(str) +} + +func ParseAddr(addr string) (string, uint32) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint32(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint32(port) } diff --git a/session/interaction.go b/session/interaction.go index 3ef056a..b22c87b 100644 --- a/session/interaction.go +++ b/session/interaction.go @@ -371,7 +371,7 @@ func updateClientSlug(oldSlug, newSlug string) bool { } delete(Clients, oldSlug) - client.forwarder.setSlug(newSlug) + client.Forwarder.setSlug(newSlug) Clients[newSlug] = client return true } diff --git a/session/session.go b/session/session.go index 6d21093..fc8c126 100644 --- a/session/session.go +++ b/session/session.go @@ -49,7 +49,6 @@ type ForwardingController interface { HandleHTTPForward(req *ssh.Request, port uint16) HandleTCPForward(req *ssh.Request, addr string, port uint16) AcceptTCPConnections() - HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn) } type Session interface { @@ -98,9 +97,9 @@ type Interaction struct { forwarder ForwarderInfo } type SSHSession struct { - lifecycle *Lifecycle - interaction *Interaction - forwarder *Forwarder + Lifecycle *Lifecycle + Interaction *Interaction + Forwarder *Forwarder Conn *ssh.ServerConn channel ssh.Channel @@ -111,10 +110,10 @@ type SSHSession struct { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { session := SSHSession{ - lifecycle: &Lifecycle{ + Lifecycle: &Lifecycle{ Status: INITIALIZING, }, - interaction: &Interaction{ + Interaction: &Interaction{ CommandBuffer: new(bytes.Buffer), EditMode: false, EditSlug: "", @@ -124,7 +123,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan session: nil, forwarder: nil, }, - forwarder: &Forwarder{ + Forwarder: &Forwarder{ Listener: nil, TunnelType: "", ForwardedPort: 0, @@ -136,12 +135,12 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan slug: "", } - session.forwarder.getSlug = session.GetSlug - session.forwarder.setSlug = session.SetSlug - session.interaction.getSlug = session.GetSlug - session.interaction.setSlug = session.SetSlug - session.interaction.session = &session - session.interaction.forwarder = session.forwarder + session.Forwarder.getSlug = session.GetSlug + session.Forwarder.setSlug = session.SetSlug + session.Interaction.getSlug = session.GetSlug + session.Interaction.setSlug = session.SetSlug + session.Interaction.session = &session + session.Interaction.forwarder = session.Forwarder go func() { go session.waitForRunningStatus() @@ -150,8 +149,8 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ch, reqs, _ := channel.Accept() if session.channel == nil { session.channel = ch - session.interaction.channel = ch - session.lifecycle.Status = SETUP + session.Interaction.channel = ch + session.Lifecycle.Status = SETUP go session.HandleGlobalRequest(forwardingReq) } go session.HandleGlobalRequest(reqs) From 52a7adc4f703d4f9c13db23829eb60cc98794093 Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 2 Dec 2025 18:21:33 +0700 Subject: [PATCH 2/8] feat: head ping --- server/http.go | 76 +++++++++++++------------------------------------- 1 file changed, 19 insertions(+), 57 deletions(-) diff --git a/server/http.go b/server/http.go index 13f341c..364fcaf 100644 --- a/server/http.go +++ b/server/http.go @@ -262,18 +262,6 @@ func init() { } func NewHTTPServer() error { - upgrader.CheckOrigin = func(r *http.Request) bool { - if isAllowedAllCors { - return true - } else { - isAllowed, ok := allowedCors[r.Header.Get("Origin")] - if !ok || !isAllowed { - return false - } - return true - } - } - listener, err := net.Listen("tcp", ":80") if err != nil { return errors.New("Error listening: " + err.Error()) @@ -313,15 +301,7 @@ func Handler(conn net.Conn) { if err != nil { return } - cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - // Initial Requests - cw.Requests = append(cw.Requests, &RequestContext{ - Host: reqhf.Get("Host"), - Path: reqhf.Path, - Method: reqhf.Method, - Chunked: false, - }) host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) @@ -329,11 +309,6 @@ func Handler(conn net.Conn) { log.Println("Failed to write 400 Bad Request:", err) return } - err = conn.Close() - if err != nil { - log.Println("Failed to close connection:", err) - return - } return } @@ -349,43 +324,22 @@ func Handler(conn net.Conn) { log.Println("Failed to write 301 Moved Permanently:", err) return } - err = conn.Close() - if err != nil { - log.Println("Failed to close connection:", err) - return - } return } if slug == "ping" { - req, err := http.ReadRequest(dstReader) + // TODO: implement cors + _, err := conn.Write([]byte( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + + "Access-Control-Allow-Headers: *\r\n" + + "\r\n", + )) if err != nil { - log.Println("failed to parse HTTP request:", err) - return - } - rw := &connResponseWriter{conn: conn} - - wsConn, err := upgrader.Upgrade(rw, req, nil) - if err != nil { - if !strings.Contains(err.Error(), "the client is not using the websocket protocol") { - log.Println("Upgrade failed:", err) - } - err := conn.Close() - if err != nil { - log.Println("failed to close connection:", err) - return - } - return - } - - err = wsConn.WriteMessage(websocket.TextMessage, []byte("pong")) - if err != nil { - log.Println("failed to write pong:", err) - return - } - err = wsConn.Close() - if err != nil { - log.Println("websocket close failed :", err) + log.Println("Failed to write 200 OK:", err) return } return @@ -409,7 +363,15 @@ func Handler(conn net.Conn) { } return } + cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) + // Initial Requests + cw.Requests = append(cw.Requests, &RequestContext{ + Host: reqhf.Get("Host"), + Path: reqhf.Path, + Method: reqhf.Method, + Chunked: false, + }) forwardRequest(cw, reqhf, sshSession) return } From ecd6ab2618b7d4cf90fc26b7cbe6385d79adabfb Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 2 Dec 2025 18:44:30 +0700 Subject: [PATCH 3/8] feat: head ping --- server/http.go | 1 + server/https.go | 60 +++++++++++++++++-------------------------------- 2 files changed, 22 insertions(+), 39 deletions(-) diff --git a/server/http.go b/server/http.go index 364fcaf..1088f84 100644 --- a/server/http.go +++ b/server/http.go @@ -299,6 +299,7 @@ func Handler(conn net.Conn) { dstReader := bufio.NewReader(conn) reqhf, err := NewRequestHeaderFactory(dstReader) if err != nil { + log.Printf("Error creating request header: %v", err) return } diff --git a/server/https.go b/server/https.go index fe3221f..51c09b8 100644 --- a/server/https.go +++ b/server/https.go @@ -7,12 +7,9 @@ import ( "fmt" "log" "net" - "net/http" "strings" "tunnel_pls/session" "tunnel_pls/utils" - - "github.com/gorilla/websocket" ) func NewHTTPSServer() error { @@ -57,17 +54,9 @@ func HandlerTLS(conn net.Conn) { dstReader := bufio.NewReader(conn) reqhf, err := NewRequestHeaderFactory(dstReader) if err != nil { + log.Printf("Error creating request header: %v", err) return } - cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - - // Initial Requests - cw.Requests = append(cw.Requests, &RequestContext{ - Host: reqhf.Get("Host"), - Path: reqhf.Path, - Method: reqhf.Method, - Chunked: false, - }) host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { @@ -87,34 +76,18 @@ func HandlerTLS(conn net.Conn) { slug := host[0] if slug == "ping" { - req, err := http.ReadRequest(dstReader) + // TODO: implement cors + _, err := conn.Write([]byte( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + + "Access-Control-Allow-Headers: *\r\n" + + "\r\n", + )) if err != nil { - log.Println("failed to parse HTTP request:", err) - return - } - rw := &connResponseWriter{conn: conn} - - wsConn, err := upgrader.Upgrade(rw, req, nil) - if err != nil { - if !strings.Contains(err.Error(), "the client is not using the websocket protocol") { - log.Println("Upgrade failed:", err) - } - err := conn.Close() - if err != nil { - log.Println("failed to close connection:", err) - return - } - return - } - - err = wsConn.WriteMessage(websocket.TextMessage, []byte("pong")) - if err != nil { - log.Println("failed to write pong:", err) - return - } - err = wsConn.Close() - if err != nil { - log.Println("websocket close failed :", err) + log.Println("Failed to write 200 OK:", err) return } return @@ -138,6 +111,15 @@ func HandlerTLS(conn net.Conn) { } return } + cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) + + // Initial Requests + cw.Requests = append(cw.Requests, &RequestContext{ + Host: reqhf.Get("Host"), + Path: reqhf.Path, + Method: reqhf.Method, + Chunked: false, + }) forwardRequest(cw, reqhf, sshSession) return } From b967619a3a3eaa72a90488a98a2620f28517c462 Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 2 Dec 2025 19:17:20 +0700 Subject: [PATCH 4/8] fix: chunk request not sent properly --- server/http.go | 83 +++++--------------------------------------------- 1 file changed, 7 insertions(+), 76 deletions(-) diff --git a/server/http.go b/server/http.go index 1088f84..319e906 100644 --- a/server/http.go +++ b/server/http.go @@ -8,7 +8,6 @@ import ( "io" "log" "net" - "net/http" "regexp" "strconv" "strings" @@ -16,7 +15,6 @@ import ( "tunnel_pls/session" "tunnel_pls/utils" - "github.com/gorilla/websocket" "golang.org/x/crypto/ssh" ) @@ -25,61 +23,6 @@ var BAD_GATEWAY_RESPONSE = []byte("HTTP/1.1 502 Bad Gateway\r\n" + "Content-Type: text/plain\r\n\r\n" + "Bad Gateway") -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} - -type connResponseWriter struct { - conn net.Conn - header http.Header - wrote bool -} - -func (w *connResponseWriter) Header() http.Header { - if w.header == nil { - w.header = make(http.Header) - } - return w.header -} - -func (w *connResponseWriter) WriteHeader(statusCode int) { - if w.wrote { - return - } - w.wrote = true - _, err := fmt.Fprintf(w.conn, "HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) - if err != nil { - log.Printf("Error writing HTTP response: %v", err) - return - } - err = w.header.Write(w.conn) - if err != nil { - log.Printf("Error writing HTTP header: %v", err) - return - } - _, err = fmt.Fprint(w.conn, "\r\n") - if err != nil { - log.Printf("Error writing HTTP header: %v", err) - return - } -} - -func (w *connResponseWriter) Write(b []byte) (int, error) { - if !w.wrote { - w.WriteHeader(http.StatusOK) - } - return w.conn.Write(b) -} - -func (w *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - rw := bufio.NewReadWriter( - bufio.NewReader(w.conn), - bufio.NewWriter(w.conn), - ) - return w.conn, rw, nil -} - type CustomWriter struct { RemoteAddr net.Addr writer io.Writer @@ -198,10 +141,9 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { } req := cw.Requests[0] req.Written += len(body) - if req.Chunked { req.Tail = append(req.Tail, p[len(p)-5:]...) - if bytes.Equal(req.Tail, []byte("0\r\n\r\n")) { + if bytes.Contains(p, []byte("0\r\n\r\n")) { cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } } else if req.ContentSize != -1 { @@ -211,6 +153,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { } } else { cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } cw.buf = nil return len(p), nil @@ -223,10 +166,10 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { } req := cw.Requests[0] - req.Written += len(p) + req.Written += n if req.Chunked { req.Tail = append(req.Tail, p[len(p)-5:]...) - if bytes.Equal(req.Tail, []byte("0\r\n\r\n")) { + if bytes.Contains(p, []byte("0\r\n\r\n")) { cw.Requests = cw.Requests[1:] cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } @@ -235,10 +178,11 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { cw.Requests = cw.Requests[1:] cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } + } else { cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } - return n, nil } @@ -247,19 +191,6 @@ func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) { } var redirectTLS = false -var allowedCors = make(map[string]bool) -var isAllowedAllCors = false - -func init() { - corsList := utils.Getenv("cors_list") - if corsList == "*" { - isAllowedAllCors = true - } else { - for _, allowedOrigin := range strings.Split(corsList, ",") { - allowedCors[allowedOrigin] = true - } - } -} func NewHTTPServer() error { listener, err := net.Listen("tcp", ":80") @@ -416,7 +347,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS _, err = channel.Write(initialRequest.Finalize()) if err != nil { - log.Printf("Failed to write forwarded-tcpip:", err) + log.Printf("Failed to forward request: %v", err) return } From 626b6b5febce0378d212dcf4a898c5de2e625c30 Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 2 Dec 2025 20:15:51 +0700 Subject: [PATCH 5/8] fix: unexpected byte size --- server/http.go | 87 +++---------------------------------------------- server/https.go | 7 ---- 2 files changed, 5 insertions(+), 89 deletions(-) diff --git a/server/http.go b/server/http.go index 319e906..d714851 100644 --- a/server/http.go +++ b/server/http.go @@ -9,7 +9,6 @@ import ( "log" "net" "regexp" - "strconv" "strings" "time" "tunnel_pls/session" @@ -29,35 +28,14 @@ type CustomWriter struct { reader io.Reader headerBuf []byte buf []byte - Requests []*RequestContext interaction *session.Interaction } -type RequestContext struct { - Host string - Path string - Method string - Chunked bool - Tail []byte - ContentSize int - Written int -} - func (cw *CustomWriter) Read(p []byte) (int, error) { read, err := cw.reader.Read(p) - test := bytes.NewReader(p) - reqhf, _ := NewRequestHeaderFactory(test) - if reqhf != nil { - cw.Requests = append(cw.Requests, &RequestContext{ - Host: reqhf.Get("Host"), - Path: reqhf.Path, - Method: reqhf.Method, - Chunked: false, - Tail: make([]byte, 5), - ContentSize: 0, - Written: 0, - }) - } + reader := bytes.NewReader(p) + reqhf, _ := NewRequestHeaderFactory(reader) + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), reqhf.Method, reqhf.Path)) return read, err } @@ -102,7 +80,6 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { } cw.buf = append(cw.buf, p...) - timestamp := time.Now().UTC().Format(time.RFC3339) // TODO: implement middleware buat cache system dll if idx := bytes.Index(cw.buf, DELIMITER); idx != -1 { header := cw.buf[:idx+len(DELIMITER)] @@ -112,21 +89,6 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { resphf := NewResponseHeaderFactory(header) resphf.Set("Server", "Tunnel Please") - if resphf.Get("Transfer-Encoding") == "chunked" { - cw.Requests[0].Chunked = true - } - if resphf.Get("Content-Length") != "" { - bodySize, err := strconv.Atoi(resphf.Get("Content-Length")) - if err != nil { - log.Printf("Error parsing Content-Length: %v", err) - cw.Requests[0].ContentSize = -1 - } else { - cw.Requests[0].ContentSize = bodySize - } - } else { - cw.Requests[0].ContentSize = -1 - } - header = resphf.Finalize() _, err := cw.writer.Write(header) if err != nil { @@ -139,22 +101,6 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return 0, err } } - req := cw.Requests[0] - req.Written += len(body) - if req.Chunked { - req.Tail = append(req.Tail, p[len(p)-5:]...) - if bytes.Contains(p, []byte("0\r\n\r\n")) { - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) - } - } else if req.ContentSize != -1 { - if req.Written >= req.ContentSize { - cw.Requests = cw.Requests[1:] - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) - } - } else { - cw.Requests = cw.Requests[1:] - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) - } cw.buf = nil return len(p), nil } @@ -165,24 +111,6 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return n, err } - req := cw.Requests[0] - req.Written += n - if req.Chunked { - req.Tail = append(req.Tail, p[len(p)-5:]...) - if bytes.Contains(p, []byte("0\r\n\r\n")) { - cw.Requests = cw.Requests[1:] - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) - } - } else if req.ContentSize != -1 { - if req.Written >= req.ContentSize { - cw.Requests = cw.Requests[1:] - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) - } - - } else { - cw.Requests = cw.Requests[1:] - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) - } return n, nil } @@ -297,13 +225,6 @@ func Handler(conn net.Conn) { } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - // Initial Requests - cw.Requests = append(cw.Requests, &RequestContext{ - Host: reqhf.Get("Host"), - Path: reqhf.Path, - Method: reqhf.Method, - Chunked: false, - }) forwardRequest(cw, reqhf, sshSession) return } @@ -351,6 +272,8 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS return } + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), initialRequest.Method, initialRequest.Path)) + sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) return } diff --git a/server/https.go b/server/https.go index 51c09b8..f4ecf99 100644 --- a/server/https.go +++ b/server/https.go @@ -113,13 +113,6 @@ func HandlerTLS(conn net.Conn) { } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - // Initial Requests - cw.Requests = append(cw.Requests, &RequestContext{ - Host: reqhf.Get("Host"), - Path: reqhf.Path, - Method: reqhf.Method, - Chunked: false, - }) forwardRequest(cw, reqhf, sshSession) return } From f59de03a5090cae8b5422e58f68f4775f22a74aa Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 2 Dec 2025 21:52:23 +0700 Subject: [PATCH 6/8] fix: panic due to nil pointer when disconnecting a session --- server/http.go | 17 +++++++++++++---- session/handler.go | 15 ++------------- session/session.go | 8 ++++---- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/server/http.go b/server/http.go index d714851..f12b10d 100644 --- a/server/http.go +++ b/server/http.go @@ -17,7 +17,7 @@ import ( "golang.org/x/crypto/ssh" ) -var BAD_GATEWAY_RESPONSE = []byte("HTTP/1.1 502 Bad Gateway\r\n" + +var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + "Content-Length: 11\r\n" + "Content-Type: text/plain\r\n\r\n" + "Bad Gateway") @@ -32,9 +32,18 @@ type CustomWriter struct { } func (cw *CustomWriter) Read(p []byte) (int, error) { + if cw == nil { + return 0, errors.New("can not read from nil CustomWriter") + } read, err := cw.reader.Read(p) reader := bytes.NewReader(p) - reqhf, _ := NewRequestHeaderFactory(reader) + reqhf, err := NewRequestHeaderFactory(reader) + if err != nil { + if errors.Is(err, io.EOF) { + return read, io.EOF + } + return 0, err + } cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), reqhf.Method, reqhf.Path)) return read, err } @@ -75,7 +84,7 @@ func isHTTPHeader(buf []byte) bool { } func (cw *CustomWriter) Write(p []byte) (int, error) { - if len(p) == len(BAD_GATEWAY_RESPONSE) && bytes.Equal(p, BAD_GATEWAY_RESPONSE) { + if len(p) == len(BadGatewayResponse) && bytes.Equal(p, BadGatewayResponse) { return cw.writer.Write(p) } @@ -279,7 +288,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS } func sendBadGatewayResponse(writer io.Writer) { - _, err := writer.Write(BAD_GATEWAY_RESPONSE) + _, err := writer.Write(BadGatewayResponse) if err != nil { log.Printf("failed to write Bad Gateway response: %v", err) return diff --git a/session/handler.go b/session/handler.go index 53c753e..5c63338 100644 --- a/session/handler.go +++ b/session/handler.go @@ -18,7 +18,7 @@ import ( "golang.org/x/crypto/ssh" ) -type SessionStatus string +type Status string var forbiddenSlug = []string{ "ping", @@ -191,7 +191,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { unassign, success := portUtil.Manager.GetUnassignedPort() portToBind = unassign if !success { - s.Interaction.SendMessage(fmt.Sprintf("No available port\r\n", portToBind)) + s.Interaction.SendMessage("No available port\r\n") err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -338,19 +338,8 @@ func (s *SSHSession) acceptTCPConnections() { log.Printf("Failed to open forwarded-tcpip channel: %v", err) return } - defer func(channel ssh.Channel) { - err := channel.Close() - if err != nil { - log.Println("Failed to close connection:", err) - } - }(channel) go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in request handler: %v", r) - } - }() for req := range reqs { err := req.Reply(false, nil) if err != nil { diff --git a/session/session.go b/session/session.go index fc8c126..c44d656 100644 --- a/session/session.go +++ b/session/session.go @@ -10,9 +10,9 @@ import ( ) const ( - INITIALIZING SessionStatus = "INITIALIZING" - RUNNING SessionStatus = "RUNNING" - SETUP SessionStatus = "SETUP" + INITIALIZING Status = "INITIALIZING" + RUNNING Status = "RUNNING" + SETUP Status = "SETUP" ) type TunnelType string @@ -58,7 +58,7 @@ type Session interface { } type Lifecycle struct { - Status SessionStatus + Status Status } type Forwarder struct { From a3eb08e7aec1674cb69ffbed05888acab52bbdb9 Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 2 Dec 2025 22:17:14 +0700 Subject: [PATCH 7/8] fix: try writing to a close network --- server/http.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/server/http.go b/server/http.go index f12b10d..22cf81e 100644 --- a/server/http.go +++ b/server/http.go @@ -36,6 +36,9 @@ func (cw *CustomWriter) Read(p []byte) (int, error) { return 0, errors.New("can not read from nil CustomWriter") } read, err := cw.reader.Read(p) + if err != nil { + return 0, err + } reader := bytes.NewReader(p) reqhf, err := NewRequestHeaderFactory(reader) if err != nil { @@ -261,11 +264,6 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS }(channel) go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in request handler: %v", r) - } - }() for req := range reqs { err := req.Reply(false, nil) if err != nil { From 515bc305599b675e60e996aae817e5ad8e6820e9 Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 3 Dec 2025 21:14:42 +0700 Subject: [PATCH 8/8] fix: conn reader stuck when header have body --- main.go | 8 +- server/http.go | 99 ++++++++++++++++++++----- server/middleware.go | 162 +++++++++++++++++++++++++++++++++++++++++ session/forwarder.go | 37 ++++++++++ session/interaction.go | 26 +++++++ session/session.go | 56 -------------- 6 files changed, 314 insertions(+), 74 deletions(-) create mode 100644 server/middleware.go create mode 100644 session/forwarder.go diff --git a/main.go b/main.go index 1fb275c..9953588 100644 --- a/main.go +++ b/main.go @@ -1,14 +1,20 @@ package main import ( - "golang.org/x/crypto/ssh" "log" + "net/http" + _ "net/http/pprof" "os" "tunnel_pls/server" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() sshConfig := &ssh.ServerConfig{ NoClientAuth: true, ServerVersion: "SSH-2.0-TunnlPls-1.0", diff --git a/server/http.go b/server/http.go index 22cf81e..3f1aaba 100644 --- a/server/http.go +++ b/server/http.go @@ -10,7 +10,6 @@ import ( "net" "regexp" "strings" - "time" "tunnel_pls/session" "tunnel_pls/utils" @@ -28,27 +27,66 @@ type CustomWriter struct { reader io.Reader headerBuf []byte buf []byte + respHeader *ResponseHeaderFactory + reqHeader *RequestHeaderFactory interaction *session.Interaction + respMW []ResponseMiddleware + reqStartMW []RequestMiddleware + reqEndMW []RequestMiddleware } func (cw *CustomWriter) Read(p []byte) (int, error) { - if cw == nil { - return 0, errors.New("can not read from nil CustomWriter") - } - read, err := cw.reader.Read(p) + tmp := make([]byte, len(p)) + read, err := cw.reader.Read(tmp) if err != nil { return 0, err } - reader := bytes.NewReader(p) - reqhf, err := NewRequestHeaderFactory(reader) - if err != nil { - if errors.Is(err, io.EOF) { - return read, io.EOF + + tmp = tmp[:read] + + idx := bytes.Index(tmp, DELIMITER) + if idx == -1 { + copy(p, tmp) + return read, nil + } + + header := tmp[:idx+len(DELIMITER)] + body := tmp[idx+len(DELIMITER):] + + if !isHTTPHeader(header) { + copy(p, tmp) + return read, nil + } + + for _, m := range cw.reqEndMW { + err := m.HandleRequest(cw.reqHeader) + if err != nil { + log.Printf("Error when applying request middleware: %v", err) + return 0, err } + } + + headerReader := bufio.NewReader(bytes.NewReader(header)) + reqhf, err := NewRequestHeaderFactory(headerReader) + if err != nil { return 0, err } - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), reqhf.Method, reqhf.Path)) - return read, err + + for _, m := range cw.reqStartMW { + err := m.HandleRequest(reqhf) + if err != nil { + log.Printf("Error when applying request middleware: %v", err) + return 0, err + } + } + + cw.reqHeader = reqhf + finalHeader := reqhf.Finalize() + + n := copy(p, finalHeader) + n += copy(p[n:], body) + + return n, nil } func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter { @@ -99,9 +137,15 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { if isHTTPHeader(header) { resphf := NewResponseHeaderFactory(header) - resphf.Set("Server", "Tunnel Please") - + for _, m := range cw.respMW { + err := m.HandleResponse(resphf, body) + if err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return 0, err + } + } header = resphf.Finalize() + cw.respHeader = resphf _, err := cw.writer.Write(header) if err != nil { return 0, err @@ -117,12 +161,19 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return len(p), nil } } + cw.buf = nil n, err := cw.writer.Write(p) if err != nil { return n, err } - + for _, m := range cw.respMW { + err := m.HandleResponse(cw.respHeader, p) + if err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return 0, err + } + } return n, nil } @@ -272,14 +323,28 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS } } }() - _, err = channel.Write(initialRequest.Finalize()) if err != nil { log.Printf("Failed to forward request: %v", err) return } + //TODO: Implement wrapper func buat add/remove middleware + fingerprintMiddleware := NewTunnelFingerprint() + loggerMiddleware := NewRequestLogger(cw.interaction, cw.RemoteAddr) + cw.respMW = append(cw.respMW, fingerprintMiddleware) + cw.reqStartMW = append(cw.reqStartMW, loggerMiddleware) - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), initialRequest.Method, initialRequest.Path)) + //TODO: Tambah req Middleware + cw.reqEndMW = nil + cw.reqHeader = initialRequest + + for _, m := range cw.reqStartMW { + err := m.HandleRequest(cw.reqHeader) + if err != nil { + log.Printf("Error handling request: %v", err) + return + } + } sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) return diff --git a/server/middleware.go b/server/middleware.go new file mode 100644 index 0000000..08ee035 --- /dev/null +++ b/server/middleware.go @@ -0,0 +1,162 @@ +package server + +import ( + "fmt" + "net" + "time" + "tunnel_pls/session" +) + +type RequestMiddleware interface { + HandleRequest(header *RequestHeaderFactory) error +} + +type ResponseMiddleware interface { + HandleResponse(header *ResponseHeaderFactory, body []byte) error +} + +type TunnelFingerprint struct{} + +func NewTunnelFingerprint() *TunnelFingerprint { + return &TunnelFingerprint{} +} +func (h *TunnelFingerprint) HandleRequest(header *RequestHeaderFactory) error { + return nil +} +func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error { + header.Set("Server", "Tunnel Please") + return nil +} + +type RequestLogger struct { + interaction session.Interaction + remoteAddr net.Addr +} + +func NewRequestLogger(interaction *session.Interaction, remoteAddr net.Addr) *RequestLogger { + return &RequestLogger{ + interaction: *interaction, + remoteAddr: remoteAddr, + } +} +func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error { + rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path)) + return nil +} +func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil } + +//TODO: Implement caching atau enggak +//const maxCacheSize = 50 * 1024 * 1024 +// +//type DiskCacheMiddleware struct { +// dir string +// mu sync.Mutex +// file *os.File +// path string +// cacheable bool +//} +// +//func NewDiskCacheMiddleware() *DiskCacheMiddleware { +// return &DiskCacheMiddleware{dir: "cache"} +//} +// +//func (c *DiskCacheMiddleware) ensureDir() error { +// return os.MkdirAll(c.dir, 0755) +//} +// +//func (c *DiskCacheMiddleware) cacheKey(method, path string) string { +// return fmt.Sprintf("%s_%s.cache", method, base64.URLEncoding.EncodeToString([]byte(path))) +//} +// +//func (c *DiskCacheMiddleware) filePath(method, path string) string { +// return filepath.Join(c.dir, c.cacheKey(method, path)) +//} +// +//func fileExists(path string) bool { +// _, err := os.Stat(path) +// if err == nil { +// return true +// } +// if os.IsNotExist(err) { +// return false +// } +// return false +//} +// +//func canCacheRequest(header *RequestHeaderFactory) bool { +// if header.Method != "GET" { +// return false +// } +// +// if cacheControl := header.Get("Cache-Control"); cacheControl != "" { +// if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "private") || strings.Contains(cacheControl, "no-cache") || strings.Contains(cacheControl, "max-age=0") { +// return false +// } +// } +// +// if header.Get("Authorization") != "" { +// return false +// } +// +// if header.Get("Cookie") != "" { +// return false +// } +// +// return true +//} +// +//func (c *DiskCacheMiddleware) HandleRequest(header *RequestHeaderFactory) error { +// if !canCacheRequest(header) { +// c.cacheable = false +// return nil +// } +// +// c.cacheable = true +// _ = c.ensureDir() +// path := c.filePath(header.Method, header.Path) +// +// if fileExists(path + ".finish") { +// c.file = nil +// return nil +// } +// +// if c.file != nil { +// err := c.file.Close() +// if err != nil { +// return err +// } +// err = os.Rename(c.path, c.path+".finish") +// if err != nil { +// return err +// } +// } +// +// c.path = path +// f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) +// if err != nil { +// return err +// } +// +// c.file = f +// +// return nil +//} +// +//func (c *DiskCacheMiddleware) HandleResponse(header *ResponseHeaderFactory, body []byte) error { +// if !c.cacheable { +// return nil +// } +// +// if c.file == nil { +// header.Set("X-Cache", "HIT") +// return nil +// } +// +// _, err := c.file.Write(body) +// if err != nil { +// return err +// } +// +// header.Set("X-Cache", "MISS") +// return nil +//} diff --git a/session/forwarder.go b/session/forwarder.go new file mode 100644 index 0000000..e7abc17 --- /dev/null +++ b/session/forwarder.go @@ -0,0 +1,37 @@ +package session + +import ( + "net" + + "golang.org/x/crypto/ssh" +) + +type Forwarder struct { + Listener net.Listener + TunnelType TunnelType + ForwardedPort uint16 + + getSlug func() string + setSlug func(string) +} + +type ForwardingController interface { + HandleGlobalRequest(ch <-chan *ssh.Request) + HandleTCPIPForward(req *ssh.Request) + HandleHTTPForward(req *ssh.Request, port uint16) + HandleTCPForward(req *ssh.Request, addr string, port uint16) + AcceptTCPConnections() +} + +type ForwarderInfo interface { + GetTunnelType() TunnelType + GetForwardedPort() uint16 +} + +func (f *Forwarder) GetTunnelType() TunnelType { + return f.TunnelType +} + +func (f *Forwarder) GetForwardedPort() uint16 { + return f.ForwardedPort +} diff --git a/session/interaction.go b/session/interaction.go index b22c87b..cfa1ce1 100644 --- a/session/interaction.go +++ b/session/interaction.go @@ -12,6 +12,32 @@ import ( "golang.org/x/crypto/ssh" ) +type InteractionController interface { + SendMessage(message string) + HandleUserInput() + HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) + HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer) + HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) + HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer) + HandleSlugUpdateError() + ShowWelcomeMessage() + DisplaySlugEditor() +} + +type Interaction struct { + CommandBuffer *bytes.Buffer + EditMode bool + EditSlug string + channel ssh.Channel + + getSlug func() string + setSlug func(string) + + session SessionCloser + + forwarder ForwarderInfo +} + func (i *Interaction) SendMessage(message string) { if i.channel != nil { _, err := i.channel.Write([]byte(message)) diff --git a/session/session.go b/session/session.go index c44d656..2a38c6a 100644 --- a/session/session.go +++ b/session/session.go @@ -3,7 +3,6 @@ package session import ( "bytes" "log" - "net" "sync" "golang.org/x/crypto/ssh" @@ -31,26 +30,6 @@ type SessionCloser interface { Close() error } -type InteractionController interface { - SendMessage(message string) - HandleUserInput() - HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer) - HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer) - HandleSlugUpdateError() - ShowWelcomeMessage() - DisplaySlugEditor() -} - -type ForwardingController interface { - HandleGlobalRequest(ch <-chan *ssh.Request) - HandleTCPIPForward(req *ssh.Request) - HandleHTTPForward(req *ssh.Request, port uint16) - HandleTCPForward(req *ssh.Request, addr string, port uint16) - AcceptTCPConnections() -} - type Session interface { SessionLifecycle InteractionController @@ -61,41 +40,6 @@ type Lifecycle struct { Status Status } -type Forwarder struct { - Listener net.Listener - TunnelType TunnelType - ForwardedPort uint16 - - getSlug func() string - setSlug func(string) -} - -type ForwarderInfo interface { - GetTunnelType() TunnelType - GetForwardedPort() uint16 -} - -func (f *Forwarder) GetTunnelType() TunnelType { - return f.TunnelType -} - -func (f *Forwarder) GetForwardedPort() uint16 { - return f.ForwardedPort -} - -type Interaction struct { - CommandBuffer *bytes.Buffer - EditMode bool - EditSlug string - channel ssh.Channel - - getSlug func() string - setSlug func(string) - - session SessionCloser - - forwarder ForwarderInfo -} type SSHSession struct { Lifecycle *Lifecycle Interaction *Interaction