diff --git a/main.go b/main.go index 1fb275c..9953588 100644 --- a/main.go +++ b/main.go @@ -1,14 +1,20 @@ package main import ( - "golang.org/x/crypto/ssh" "log" + "net/http" + _ "net/http/pprof" "os" "tunnel_pls/server" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() sshConfig := &ssh.ServerConfig{ NoClientAuth: true, ServerVersion: "SSH-2.0-TunnlPls-1.0", diff --git a/server/header.go b/server/header.go new file mode 100644 index 0000000..cb7602e --- /dev/null +++ b/server/header.go @@ -0,0 +1,163 @@ +package server + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" +) + +type HeaderManager interface { + Get(key string) []byte + Set(key string, value []byte) + Remove(key string) + Finalize() []byte +} + +type ResponseHeaderFactory struct { + startLine []byte + headers map[string]string +} + +type RequestHeaderFactory struct { + Method string + Path string + Version string + startLine []byte + headers map[string]string +} + +func NewRequestHeaderFactory(r io.Reader) (*RequestHeaderFactory, error) { + br := bufio.NewReader(r) + header := &RequestHeaderFactory{ + headers: make(map[string]string), + } + + startLine, err := br.ReadString('\n') + if err != nil { + return nil, err + } + startLine = strings.TrimRight(startLine, "\r\n") + header.startLine = []byte(startLine) + + parts := strings.Split(startLine, " ") + if len(parts) < 3 { + return nil, fmt.Errorf("invalid request line") + } + + header.Method = parts[0] + header.Path = parts[1] + header.Version = parts[2] + + for { + line, err := br.ReadString('\n') + if err != nil { + return nil, err + } + line = strings.TrimRight(line, "\r\n") + + if line == "" { + break + } + + kv := strings.SplitN(line, ":", 2) + if len(kv) != 2 { + continue + } + header.headers[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) + } + + return header, nil +} + +func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory { + header := &ResponseHeaderFactory{ + startLine: nil, + headers: make(map[string]string), + } + lines := bytes.Split(startLine, []byte("\r\n")) + if len(lines) == 0 { + return header + } + header.startLine = lines[0] + for _, h := range lines[1:] { + if len(h) == 0 { + continue + } + + parts := bytes.SplitN(h, []byte(":"), 2) + if len(parts) < 2 { + continue + } + + key := parts[0] + val := bytes.TrimSpace(parts[1]) + header.headers[string(key)] = string(val) + } + return header +} + +func (resp *ResponseHeaderFactory) Get(key string) string { + return resp.headers[key] +} + +func (resp *ResponseHeaderFactory) Set(key string, value string) { + resp.headers[key] = value +} + +func (resp *ResponseHeaderFactory) Remove(key string) { + delete(resp.headers, key) +} + +func (resp *ResponseHeaderFactory) Finalize() []byte { + var buf bytes.Buffer + + buf.Write(resp.startLine) + buf.WriteString("\r\n") + + for key, val := range resp.headers { + buf.WriteString(key) + buf.WriteString(": ") + buf.WriteString(val) + buf.WriteString("\r\n") + } + + buf.WriteString("\r\n") + return buf.Bytes() +} + +func (req *RequestHeaderFactory) Get(key string) string { + val, ok := req.headers[key] + if !ok { + return "" + } + return val +} + +func (req *RequestHeaderFactory) Set(key string, value string) { + req.headers[key] = value +} + +func (req *RequestHeaderFactory) Remove(key string) { + delete(req.headers, key) +} + +func (req *RequestHeaderFactory) Finalize() []byte { + var buf bytes.Buffer + + buf.Write(req.startLine) + buf.WriteString("\r\n") + + req.headers["X-HF"] = "modified" + + for key, val := range req.headers { + buf.WriteString(key) + buf.WriteString(": ") + buf.WriteString(val) + buf.WriteString("\r\n") + } + + buf.WriteString("\r\n") + return buf.Bytes() +} diff --git a/server/http.go b/server/http.go index 49e23dd..3f1aaba 100644 --- a/server/http.go +++ b/server/http.go @@ -5,99 +5,185 @@ import ( "bytes" "errors" "fmt" + "io" "log" "net" - "net/http" + "regexp" "strings" "tunnel_pls/session" "tunnel_pls/utils" - "github.com/gorilla/websocket" + "golang.org/x/crypto/ssh" ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, +var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + + "Content-Length: 11\r\n" + + "Content-Type: text/plain\r\n\r\n" + + "Bad Gateway") + +type CustomWriter struct { + RemoteAddr net.Addr + writer io.Writer + reader io.Reader + headerBuf []byte + buf []byte + respHeader *ResponseHeaderFactory + reqHeader *RequestHeaderFactory + interaction *session.Interaction + respMW []ResponseMiddleware + reqStartMW []RequestMiddleware + reqEndMW []RequestMiddleware } -type connResponseWriter struct { - conn net.Conn - header http.Header - wrote bool -} - -func (w *connResponseWriter) Header() http.Header { - if w.header == nil { - w.header = make(http.Header) - } - return w.header -} - -func (w *connResponseWriter) WriteHeader(statusCode int) { - if w.wrote { - return - } - w.wrote = true - _, err := fmt.Fprintf(w.conn, "HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) +func (cw *CustomWriter) Read(p []byte) (int, error) { + tmp := make([]byte, len(p)) + read, err := cw.reader.Read(tmp) if err != nil { - log.Printf("Error writing HTTP response: %v", err) - return + return 0, err } - err = w.header.Write(w.conn) - if err != nil { - log.Printf("Error writing HTTP header: %v", err) - return + + tmp = tmp[:read] + + idx := bytes.Index(tmp, DELIMITER) + if idx == -1 { + copy(p, tmp) + return read, nil } - _, err = fmt.Fprint(w.conn, "\r\n") + + header := tmp[:idx+len(DELIMITER)] + body := tmp[idx+len(DELIMITER):] + + if !isHTTPHeader(header) { + copy(p, tmp) + return read, nil + } + + for _, m := range cw.reqEndMW { + err := m.HandleRequest(cw.reqHeader) + if err != nil { + log.Printf("Error when applying request middleware: %v", err) + return 0, err + } + } + + headerReader := bufio.NewReader(bytes.NewReader(header)) + reqhf, err := NewRequestHeaderFactory(headerReader) if err != nil { - log.Printf("Error writing HTTP header: %v", err) - return + return 0, err + } + + for _, m := range cw.reqStartMW { + err := m.HandleRequest(reqhf) + if err != nil { + log.Printf("Error when applying request middleware: %v", err) + return 0, err + } + } + + cw.reqHeader = reqhf + finalHeader := reqhf.Finalize() + + n := copy(p, finalHeader) + n += copy(p[n:], body) + + return n, nil +} + +func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter { + return &CustomWriter{ + RemoteAddr: remoteAddr, + writer: writer, + reader: reader, + buf: make([]byte, 0, 4096), + interaction: nil, } } -func (w *connResponseWriter) Write(b []byte) (int, error) { - if !w.wrote { - w.WriteHeader(http.StatusOK) +var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} // HTTP HEADER DELIMITER `\r\n\r\n` +var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) +var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) + +func isHTTPHeader(buf []byte) bool { + lines := bytes.Split(buf, []byte("\r\n")) + if len(lines) < 1 { + return false } - return w.conn.Write(b) + startLine := string(lines[0]) + if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { + return false + } + + for _, line := range lines[1:] { + if len(line) == 0 { + break + } + if !bytes.Contains(line, []byte(":")) { + return false + } + } + return true } -func (w *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - rw := bufio.NewReadWriter( - bufio.NewReader(w.conn), - bufio.NewWriter(w.conn), - ) - return w.conn, rw, nil +func (cw *CustomWriter) Write(p []byte) (int, error) { + if len(p) == len(BadGatewayResponse) && bytes.Equal(p, BadGatewayResponse) { + return cw.writer.Write(p) + } + + cw.buf = append(cw.buf, p...) + // TODO: implement middleware buat cache system dll + if idx := bytes.Index(cw.buf, DELIMITER); idx != -1 { + header := cw.buf[:idx+len(DELIMITER)] + body := cw.buf[idx+len(DELIMITER):] + + if isHTTPHeader(header) { + resphf := NewResponseHeaderFactory(header) + for _, m := range cw.respMW { + err := m.HandleResponse(resphf, body) + if err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return 0, err + } + } + header = resphf.Finalize() + cw.respHeader = resphf + _, err := cw.writer.Write(header) + if err != nil { + return 0, err + } + + if len(body) > 0 { + _, err := cw.writer.Write(body) + if err != nil { + return 0, err + } + } + cw.buf = nil + return len(p), nil + } + } + + cw.buf = nil + n, err := cw.writer.Write(p) + if err != nil { + return n, err + } + for _, m := range cw.respMW { + err := m.HandleResponse(cw.respHeader, p) + if err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return 0, err + } + } + return n, nil +} + +func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) { + cw.interaction = interaction } var redirectTLS = false -var allowedCors = make(map[string]bool) -var isAllowedAllCors = false - -func init() { - corsList := utils.Getenv("cors_list") - if corsList == "*" { - isAllowedAllCors = true - } else { - for _, allowedOrigin := range strings.Split(corsList, ",") { - allowedCors[allowedOrigin] = true - } - } -} func NewHTTPServer() error { - upgrader.CheckOrigin = func(r *http.Request) bool { - if isAllowedAllCors { - return true - } else { - isAllowed, ok := allowedCors[r.Header.Get("Origin")] - if !ok || !isAllowed { - return false - } - return true - } - } - listener, err := net.Listen("tcp", ":80") if err != nil { return errors.New("Error listening: " + err.Error()) @@ -123,25 +209,29 @@ func NewHTTPServer() error { } func Handler(conn net.Conn) { - reader := bufio.NewReader(conn) - headers, err := peekUntilHeaders(reader, 8192) + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Error closing connection: %v", err) + return + } + return + }() + + dstReader := bufio.NewReader(conn) + reqhf, err := NewRequestHeaderFactory(dstReader) if err != nil { - log.Println("Failed to peek headers:", err) + log.Printf("Error creating request header: %v", err) return } - host := strings.Split(parseHostFromHeader(headers), ".") + host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) if err != nil { log.Println("Failed to write 400 Bad Request:", err) return } - err = conn.Close() - if err != nil { - log.Println("Failed to close connection:", err) - return - } return } @@ -157,43 +247,22 @@ func Handler(conn net.Conn) { log.Println("Failed to write 301 Moved Permanently:", err) return } - err = conn.Close() - if err != nil { - log.Println("Failed to close connection:", err) - return - } return } if slug == "ping" { - req, err := http.ReadRequest(reader) + // TODO: implement cors + _, err := conn.Write([]byte( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + + "Access-Control-Allow-Headers: *\r\n" + + "\r\n", + )) if err != nil { - log.Println("failed to parse HTTP request:", err) - return - } - rw := &connResponseWriter{conn: conn} - - wsConn, err := upgrader.Upgrade(rw, req, nil) - if err != nil { - if !strings.Contains(err.Error(), "the client is not using the websocket protocol") { - log.Println("Upgrade failed:", err) - } - err := conn.Close() - if err != nil { - log.Println("failed to close connection:", err) - return - } - return - } - - err = wsConn.WriteMessage(websocket.TextMessage, []byte("pong")) - if err != nil { - log.Println("failed to write pong:", err) - return - } - err = wsConn.Close() - if err != nil { - log.Println("websocket close failed :", err) + log.Println("Failed to write 200 OK:", err) return } return @@ -217,40 +286,74 @@ func Handler(conn net.Conn) { } return } + cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - sshSession.HandleForwardedConnection(session.UserConnection{ - Reader: reader, - Writer: conn, - }, sshSession.Conn) + forwardRequest(cw, reqhf, sshSession) return } -func peekUntilHeaders(reader *bufio.Reader, maxBytes int) ([]byte, error) { - var buf []byte - for { - n := len(buf) + 1 - if n > maxBytes { - return buf, nil - } - - peek, err := reader.Peek(n) +func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { + cw.AddInteraction(sshSession.Interaction) + originHost, originPort := ParseAddr(cw.RemoteAddr.String()) + payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort()) + channel, reqs, err := sshSession.Conn.OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + sendBadGatewayResponse(cw) + return + } + defer func(channel ssh.Channel) { + err := channel.Close() if err != nil { - return nil, err + if errors.Is(err, io.EOF) { + sendBadGatewayResponse(cw) + return + } + log.Println("Failed to close connection:", err) + return } - buf = peek + }(channel) - if bytes.Contains(buf, []byte("\r\n\r\n")) { - return buf, nil + go func() { + for req := range reqs { + err := req.Reply(false, nil) + if err != nil { + log.Printf("Failed to reply to request: %v", err) + return + } + } + }() + _, err = channel.Write(initialRequest.Finalize()) + if err != nil { + log.Printf("Failed to forward request: %v", err) + return + } + //TODO: Implement wrapper func buat add/remove middleware + fingerprintMiddleware := NewTunnelFingerprint() + loggerMiddleware := NewRequestLogger(cw.interaction, cw.RemoteAddr) + cw.respMW = append(cw.respMW, fingerprintMiddleware) + cw.reqStartMW = append(cw.reqStartMW, loggerMiddleware) + + //TODO: Tambah req Middleware + cw.reqEndMW = nil + cw.reqHeader = initialRequest + + for _, m := range cw.reqStartMW { + err := m.HandleRequest(cw.reqHeader) + if err != nil { + log.Printf("Error handling request: %v", err) + return } } + + sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) + return } -func parseHostFromHeader(data []byte) string { - lines := strings.Split(string(data), "\r\n") - for _, line := range lines { - if strings.HasPrefix(strings.ToLower(line), "host:") { - return strings.TrimSpace(strings.TrimPrefix(line, "Host:")) - } +func sendBadGatewayResponse(writer io.Writer) { + _, err := writer.Write(BadGatewayResponse) + if err != nil { + log.Printf("failed to write Bad Gateway response: %v", err) + return } - return "" } diff --git a/server/https.go b/server/https.go index 649287d..f4ecf99 100644 --- a/server/https.go +++ b/server/https.go @@ -7,12 +7,9 @@ import ( "fmt" "log" "net" - "net/http" "strings" "tunnel_pls/session" "tunnel_pls/utils" - - "github.com/gorilla/websocket" ) func NewHTTPSServer() error { @@ -45,14 +42,23 @@ func NewHTTPSServer() error { } func HandlerTLS(conn net.Conn) { - reader := bufio.NewReader(conn) - headers, err := peekUntilHeaders(reader, 8192) + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Error closing connection: %v", err) + return + } + return + }() + + dstReader := bufio.NewReader(conn) + reqhf, err := NewRequestHeaderFactory(dstReader) if err != nil { - log.Println("Failed to peek headers:", err) + log.Printf("Error creating request header: %v", err) return } - host := strings.Split(parseHostFromHeader(headers), ".") + host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) if err != nil { @@ -70,34 +76,18 @@ func HandlerTLS(conn net.Conn) { slug := host[0] if slug == "ping" { - req, err := http.ReadRequest(reader) + // TODO: implement cors + _, err := conn.Write([]byte( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + + "Access-Control-Allow-Headers: *\r\n" + + "\r\n", + )) if err != nil { - log.Println("failed to parse HTTP request:", err) - return - } - rw := &connResponseWriter{conn: conn} - - wsConn, err := upgrader.Upgrade(rw, req, nil) - if err != nil { - if !strings.Contains(err.Error(), "the client is not using the websocket protocol") { - log.Println("Upgrade failed:", err) - } - err := conn.Close() - if err != nil { - log.Println("failed to close connection:", err) - return - } - return - } - - err = wsConn.WriteMessage(websocket.TextMessage, []byte("pong")) - if err != nil { - log.Println("failed to write pong:", err) - return - } - err = wsConn.Close() - if err != nil { - log.Println("websocket close failed :", err) + log.Println("Failed to write 200 OK:", err) return } return @@ -121,10 +111,8 @@ func HandlerTLS(conn net.Conn) { } return } + cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - sshSession.HandleForwardedConnection(session.UserConnection{ - Reader: reader, - Writer: conn, - }, sshSession.Conn) + forwardRequest(cw, reqhf, sshSession) return } diff --git a/server/middleware.go b/server/middleware.go new file mode 100644 index 0000000..08ee035 --- /dev/null +++ b/server/middleware.go @@ -0,0 +1,162 @@ +package server + +import ( + "fmt" + "net" + "time" + "tunnel_pls/session" +) + +type RequestMiddleware interface { + HandleRequest(header *RequestHeaderFactory) error +} + +type ResponseMiddleware interface { + HandleResponse(header *ResponseHeaderFactory, body []byte) error +} + +type TunnelFingerprint struct{} + +func NewTunnelFingerprint() *TunnelFingerprint { + return &TunnelFingerprint{} +} +func (h *TunnelFingerprint) HandleRequest(header *RequestHeaderFactory) error { + return nil +} +func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error { + header.Set("Server", "Tunnel Please") + return nil +} + +type RequestLogger struct { + interaction session.Interaction + remoteAddr net.Addr +} + +func NewRequestLogger(interaction *session.Interaction, remoteAddr net.Addr) *RequestLogger { + return &RequestLogger{ + interaction: *interaction, + remoteAddr: remoteAddr, + } +} +func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error { + rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path)) + return nil +} +func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil } + +//TODO: Implement caching atau enggak +//const maxCacheSize = 50 * 1024 * 1024 +// +//type DiskCacheMiddleware struct { +// dir string +// mu sync.Mutex +// file *os.File +// path string +// cacheable bool +//} +// +//func NewDiskCacheMiddleware() *DiskCacheMiddleware { +// return &DiskCacheMiddleware{dir: "cache"} +//} +// +//func (c *DiskCacheMiddleware) ensureDir() error { +// return os.MkdirAll(c.dir, 0755) +//} +// +//func (c *DiskCacheMiddleware) cacheKey(method, path string) string { +// return fmt.Sprintf("%s_%s.cache", method, base64.URLEncoding.EncodeToString([]byte(path))) +//} +// +//func (c *DiskCacheMiddleware) filePath(method, path string) string { +// return filepath.Join(c.dir, c.cacheKey(method, path)) +//} +// +//func fileExists(path string) bool { +// _, err := os.Stat(path) +// if err == nil { +// return true +// } +// if os.IsNotExist(err) { +// return false +// } +// return false +//} +// +//func canCacheRequest(header *RequestHeaderFactory) bool { +// if header.Method != "GET" { +// return false +// } +// +// if cacheControl := header.Get("Cache-Control"); cacheControl != "" { +// if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "private") || strings.Contains(cacheControl, "no-cache") || strings.Contains(cacheControl, "max-age=0") { +// return false +// } +// } +// +// if header.Get("Authorization") != "" { +// return false +// } +// +// if header.Get("Cookie") != "" { +// return false +// } +// +// return true +//} +// +//func (c *DiskCacheMiddleware) HandleRequest(header *RequestHeaderFactory) error { +// if !canCacheRequest(header) { +// c.cacheable = false +// return nil +// } +// +// c.cacheable = true +// _ = c.ensureDir() +// path := c.filePath(header.Method, header.Path) +// +// if fileExists(path + ".finish") { +// c.file = nil +// return nil +// } +// +// if c.file != nil { +// err := c.file.Close() +// if err != nil { +// return err +// } +// err = os.Rename(c.path, c.path+".finish") +// if err != nil { +// return err +// } +// } +// +// c.path = path +// f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) +// if err != nil { +// return err +// } +// +// c.file = f +// +// return nil +//} +// +//func (c *DiskCacheMiddleware) HandleResponse(header *ResponseHeaderFactory, body []byte) error { +// if !c.cacheable { +// return nil +// } +// +// if c.file == nil { +// header.Set("X-Cache", "HIT") +// return nil +// } +// +// _, err := c.file.Write(body) +// if err != nil { +// return err +// } +// +// header.Set("X-Cache", "MISS") +// return nil +//} diff --git a/server/server.go b/server/server.go index 3f5e739..0e6bdb6 100644 --- a/server/server.go +++ b/server/server.go @@ -1,12 +1,16 @@ package server import ( + "bytes" + "encoding/binary" "fmt" - "golang.org/x/crypto/ssh" "log" "net" "net/http" + "strconv" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) type Server struct { @@ -54,3 +58,41 @@ func (s *Server) Start() { go s.handleConnection(conn) } } + +func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { + var buf bytes.Buffer + + writeSSHString(&buf, "localhost") + err := binary.Write(&buf, binary.BigEndian, uint32(port)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + writeSSHString(&buf, host) + err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + + return buf.Bytes() +} + +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return + } + buffer.WriteString(str) +} + +func ParseAddr(addr string) (string, uint32) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint32(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint32(port) +} diff --git a/session/forwarder.go b/session/forwarder.go new file mode 100644 index 0000000..e7abc17 --- /dev/null +++ b/session/forwarder.go @@ -0,0 +1,37 @@ +package session + +import ( + "net" + + "golang.org/x/crypto/ssh" +) + +type Forwarder struct { + Listener net.Listener + TunnelType TunnelType + ForwardedPort uint16 + + getSlug func() string + setSlug func(string) +} + +type ForwardingController interface { + HandleGlobalRequest(ch <-chan *ssh.Request) + HandleTCPIPForward(req *ssh.Request) + HandleHTTPForward(req *ssh.Request, port uint16) + HandleTCPForward(req *ssh.Request, addr string, port uint16) + AcceptTCPConnections() +} + +type ForwarderInfo interface { + GetTunnelType() TunnelType + GetForwardedPort() uint16 +} + +func (f *Forwarder) GetTunnelType() TunnelType { + return f.TunnelType +} + +func (f *Forwarder) GetForwardedPort() uint16 { + return f.ForwardedPort +} diff --git a/session/handler.go b/session/handler.go index b30c7ff..5c63338 100644 --- a/session/handler.go +++ b/session/handler.go @@ -1,9 +1,7 @@ package session import ( - "bufio" "bytes" - "context" "encoding/binary" "errors" "fmt" @@ -20,7 +18,7 @@ import ( "golang.org/x/crypto/ssh" ) -type SessionStatus string +type Status string var forbiddenSlug = []string{ "ping", @@ -56,8 +54,8 @@ func unregisterClient(slug string) { } func (s *SSHSession) Close() error { - if s.forwarder.Listener != nil { - err := s.forwarder.Listener.Close() + if s.Forwarder.Listener != nil { + err := s.Forwarder.Listener.Close() if err != nil && !errors.Is(err, net.ErrClosed) { return err } @@ -77,13 +75,13 @@ func (s *SSHSession) Close() error { } } - slug := s.forwarder.getSlug() + slug := s.Forwarder.getSlug() if slug != "" { unregisterClient(slug) } - if s.forwarder.TunnelType == TCP && s.forwarder.Listener != nil { - err := portUtil.Manager.SetPortStatus(s.forwarder.ForwardedPort, false) + if s.Forwarder.TunnelType == TCP && s.Forwarder.Listener != nil { + err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) if err != nil { return err } @@ -138,7 +136,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { var rawPortToBind uint32 if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { log.Println("Failed to read port from payload:", err) - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -152,7 +150,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { } if rawPortToBind > 65535 { - s.interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -168,7 +166,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { portToBind := uint16(rawPortToBind) if isBlockedPort(portToBind) { - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -181,9 +179,9 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { return } - s.interaction.SendMessage("\033[H\033[2J") - s.lifecycle.Status = RUNNING - go s.interaction.HandleUserInput() + s.Interaction.SendMessage("\033[H\033[2J") + s.Lifecycle.Status = RUNNING + go s.Interaction.HandleUserInput() if portToBind == 80 || portToBind == 443 { s.handleHTTPForward(req, portToBind) @@ -193,7 +191,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { unassign, success := portUtil.Manager.GetUnassignedPort() portToBind = unassign if !success { - s.interaction.SendMessage(fmt.Sprintf("No available port\r\n", portToBind)) + s.Interaction.SendMessage("No available port\r\n") err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -206,7 +204,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { return } } else if isUse, isExist := portUtil.Manager.GetPortStatus(portToBind); isExist && isUse { - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -245,8 +243,8 @@ func isBlockedPort(port uint16) bool { } func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { - s.forwarder.TunnelType = HTTP - s.forwarder.ForwardedPort = portToBind + s.Forwarder.TunnelType = HTTP + s.Forwarder.ForwardedPort = portToBind slug := generateUniqueSlug() if slug == "" { @@ -258,7 +256,7 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { return } - s.forwarder.setSlug(slug) + s.Forwarder.setSlug(slug) registerClient(slug, s) buf := new(bytes.Buffer) @@ -275,8 +273,8 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { protocol = "https" } - s.interaction.ShowWelcomeMessage() - s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) + s.Interaction.ShowWelcomeMessage() + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) @@ -285,12 +283,12 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { } func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - s.forwarder.TunnelType = TCP + s.Forwarder.TunnelType = TCP log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { - s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) + s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -302,10 +300,10 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind } return } - s.forwarder.Listener = listener - s.forwarder.ForwardedPort = portToBind - s.interaction.ShowWelcomeMessage() - s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.forwarder.TunnelType, utils.Getenv("domain"), s.forwarder.ForwardedPort)) + s.Forwarder.Listener = listener + s.Forwarder.ForwardedPort = portToBind + s.Interaction.ShowWelcomeMessage() + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.TunnelType, utils.Getenv("domain"), s.Forwarder.ForwardedPort)) go s.acceptTCPConnections() @@ -325,7 +323,7 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind func (s *SSHSession) acceptTCPConnections() { for { - conn, err := s.forwarder.Listener.Accept() + conn, err := s.Forwarder.Listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -333,11 +331,24 @@ func (s *SSHSession) acceptTCPConnections() { log.Printf("Error accepting connection: %v", err) continue } + originHost, originPort := ParseAddr(conn.RemoteAddr().String()) + payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort()) + channel, reqs, err := s.Conn.OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + return + } - go s.HandleForwardedConnection(UserConnection{ - Reader: nil, - Writer: conn, - }, s.Conn) + go func() { + for req := range reqs { + err := req.Reply(false, nil) + if err != nil { + log.Printf("Failed to reply to request: %v", err) + return + } + } + }() + go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr()) } } @@ -369,15 +380,15 @@ func (s *SSHSession) waitForRunningStatus() { for { select { case <-ticker.C: - s.interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) + s.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) i = (i + 1) % len(frames) - if s.lifecycle.Status == RUNNING { - s.interaction.SendMessage("\r\033[K") + if s.Lifecycle.Status == RUNNING { + s.Interaction.SendMessage("\r\033[K") return } case <-timeout: - s.interaction.SendMessage("\r\033[K") - s.interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") + s.Interaction.SendMessage("\r\033[K") + s.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") err := s.Close() if err != nil { log.Printf("failed to close session: %v", err) @@ -425,137 +436,40 @@ func waitForKeyPress(connection ssh.Channel) { } } -func (s *SSHSession) HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn) { - defer func(Writer net.Conn) { - err := Writer.Close() - if err != nil { - log.Println("Failed to close connection:", err) +func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { + defer func(src ssh.Channel) { + err := src.Close() + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing connection: %v", err) } - }(conn.Writer) - - log.Printf("Handling new forwarded connection from %s", conn.Writer.RemoteAddr()) - host, originPort := ParseAddr(conn.Writer.RemoteAddr().String()) - - timestamp := time.Now().Format("02/Jan/2006 15:04:05") - - payload := createForwardedTCPIPPayload(host, uint16(originPort), s.forwarder.ForwardedPort) - channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - sendBadGatewayResponse(conn.Writer) - return - } - defer func(channel ssh.Channel) { - err := channel.Close() - if err != nil { - log.Println("Failed to close connection:", err) - } - }(channel) + }(src) + log.Printf("Handling new forwarded connection from %s", remoteAddr) go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in request handler: %v", r) - } - }() - for req := range reqs { - err := req.Reply(false, nil) - if err != nil { - log.Printf("Failed to reply to request: %v", err) - return - } - } - }() - - if conn.Reader == nil { - conn.Reader = bufio.NewReader(conn.Writer) - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in reader copy: %v", r) - } - cancel() - }() - _, err := io.Copy(channel, conn.Reader) + _, err := io.Copy(src, dst) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { log.Printf("Error copying from conn.Reader to channel: %v", err) } - cancel() }() - reader := bufio.NewReader(channel) - - peekChan := make(chan error, 1) - go func() { - _, err := reader.Peek(1) - peekChan <- err - }() - - select { - case err := <-peekChan: - if err == io.EOF { - s.interaction.SendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - } - if err != nil { - log.Printf("Error peeking channel data: %v", err) - s.interaction.SendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - } - case <-time.After(5 * time.Second): - log.Printf("Timeout waiting for channel data from %s", conn.Writer.RemoteAddr()) - s.interaction.SendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - case <-ctx.Done(): - return - } - - s.interaction.SendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.forwarder.TunnelType, timestamp)) - - _, err = io.Copy(conn.Writer, reader) + _, err := io.Copy(dst, src) if err != nil && !errors.Is(err, io.EOF) { log.Printf("Error copying from channel to conn.Writer: %v", err) } + return } -func sendBadGatewayResponse(writer io.Writer) { - response := "HTTP/1.1 502 Bad Gateway\r\n" + - "Content-Length: 11\r\n" + - "Content-Type: text/plain\r\n\r\n" + - "Bad Gateway" - _, err := io.Copy(writer, bytes.NewReader([]byte(response))) - if err != nil { - log.Printf("failed to write Bad Gateway response: %v", err) - return +func readSSHString(reader *bytes.Reader) (string, error) { + var length uint32 + if err := binary.Read(reader, binary.BigEndian, &length); err != nil { + return "", err } -} - -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return + strBytes := make([]byte, length) + if _, err := reader.Read(strBytes); err != nil { + return "", err } - buffer.WriteString(str) -} - -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) + return string(strBytes), nil } func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { @@ -577,14 +491,21 @@ func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { return buf.Bytes() } -func readSSHString(reader *bytes.Reader) (string, error) { - var length uint32 - if err := binary.Read(reader, binary.BigEndian, &length); err != nil { - return "", err +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return } - strBytes := make([]byte, length) - if _, err := reader.Read(strBytes); err != nil { - return "", err - } - return string(strBytes), nil + buffer.WriteString(str) +} + +func ParseAddr(addr string) (string, uint32) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint32(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint32(port) } diff --git a/session/interaction.go b/session/interaction.go index 3ef056a..cfa1ce1 100644 --- a/session/interaction.go +++ b/session/interaction.go @@ -12,6 +12,32 @@ import ( "golang.org/x/crypto/ssh" ) +type InteractionController interface { + SendMessage(message string) + HandleUserInput() + HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) + HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer) + HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) + HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer) + HandleSlugUpdateError() + ShowWelcomeMessage() + DisplaySlugEditor() +} + +type Interaction struct { + CommandBuffer *bytes.Buffer + EditMode bool + EditSlug string + channel ssh.Channel + + getSlug func() string + setSlug func(string) + + session SessionCloser + + forwarder ForwarderInfo +} + func (i *Interaction) SendMessage(message string) { if i.channel != nil { _, err := i.channel.Write([]byte(message)) @@ -371,7 +397,7 @@ func updateClientSlug(oldSlug, newSlug string) bool { } delete(Clients, oldSlug) - client.forwarder.setSlug(newSlug) + client.Forwarder.setSlug(newSlug) Clients[newSlug] = client return true } diff --git a/session/session.go b/session/session.go index 6d21093..2a38c6a 100644 --- a/session/session.go +++ b/session/session.go @@ -3,16 +3,15 @@ package session import ( "bytes" "log" - "net" "sync" "golang.org/x/crypto/ssh" ) const ( - INITIALIZING SessionStatus = "INITIALIZING" - RUNNING SessionStatus = "RUNNING" - SETUP SessionStatus = "SETUP" + INITIALIZING Status = "INITIALIZING" + RUNNING Status = "RUNNING" + SETUP Status = "SETUP" ) type TunnelType string @@ -31,27 +30,6 @@ type SessionCloser interface { Close() error } -type InteractionController interface { - SendMessage(message string) - HandleUserInput() - HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer) - HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer) - HandleSlugUpdateError() - ShowWelcomeMessage() - DisplaySlugEditor() -} - -type ForwardingController interface { - HandleGlobalRequest(ch <-chan *ssh.Request) - HandleTCPIPForward(req *ssh.Request) - HandleHTTPForward(req *ssh.Request, port uint16) - HandleTCPForward(req *ssh.Request, addr string, port uint16) - AcceptTCPConnections() - HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn) -} - type Session interface { SessionLifecycle InteractionController @@ -59,48 +37,13 @@ type Session interface { } type Lifecycle struct { - Status SessionStatus + Status Status } -type Forwarder struct { - Listener net.Listener - TunnelType TunnelType - ForwardedPort uint16 - - getSlug func() string - setSlug func(string) -} - -type ForwarderInfo interface { - GetTunnelType() TunnelType - GetForwardedPort() uint16 -} - -func (f *Forwarder) GetTunnelType() TunnelType { - return f.TunnelType -} - -func (f *Forwarder) GetForwardedPort() uint16 { - return f.ForwardedPort -} - -type Interaction struct { - CommandBuffer *bytes.Buffer - EditMode bool - EditSlug string - channel ssh.Channel - - getSlug func() string - setSlug func(string) - - session SessionCloser - - forwarder ForwarderInfo -} type SSHSession struct { - lifecycle *Lifecycle - interaction *Interaction - forwarder *Forwarder + Lifecycle *Lifecycle + Interaction *Interaction + Forwarder *Forwarder Conn *ssh.ServerConn channel ssh.Channel @@ -111,10 +54,10 @@ type SSHSession struct { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { session := SSHSession{ - lifecycle: &Lifecycle{ + Lifecycle: &Lifecycle{ Status: INITIALIZING, }, - interaction: &Interaction{ + Interaction: &Interaction{ CommandBuffer: new(bytes.Buffer), EditMode: false, EditSlug: "", @@ -124,7 +67,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan session: nil, forwarder: nil, }, - forwarder: &Forwarder{ + Forwarder: &Forwarder{ Listener: nil, TunnelType: "", ForwardedPort: 0, @@ -136,12 +79,12 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan slug: "", } - session.forwarder.getSlug = session.GetSlug - session.forwarder.setSlug = session.SetSlug - session.interaction.getSlug = session.GetSlug - session.interaction.setSlug = session.SetSlug - session.interaction.session = &session - session.interaction.forwarder = session.forwarder + session.Forwarder.getSlug = session.GetSlug + session.Forwarder.setSlug = session.SetSlug + session.Interaction.getSlug = session.GetSlug + session.Interaction.setSlug = session.SetSlug + session.Interaction.session = &session + session.Interaction.forwarder = session.Forwarder go func() { go session.waitForRunningStatus() @@ -150,8 +93,8 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ch, reqs, _ := channel.Accept() if session.channel == nil { session.channel = ch - session.interaction.channel = ch - session.lifecycle.Status = SETUP + session.Interaction.channel = ch + session.Lifecycle.Status = SETUP go session.HandleGlobalRequest(forwardingReq) } go session.HandleGlobalRequest(reqs)