diff --git a/main.go b/main.go index 04de238..3dd3811 100644 --- a/main.go +++ b/main.go @@ -10,28 +10,25 @@ import ( ) func main() { + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + sshConfig := &ssh.ServerConfig{ NoClientAuth: true, ServerVersion: "SSH-2.0-TunnlPls-1.0", - PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - return nil, nil - }, } - log.SetOutput(os.Stdout) - log.SetFlags(log.LstdFlags | log.Lshortfile) - privateBytes, err := os.ReadFile(utils.Getenv("ssh_private_key")) if err != nil { - log.Fatalf("Failed to load private key : %s", err.Error()) + log.Fatalf("Failed to load private key: %s", err) } private, err := ssh.ParsePrivateKey(privateBytes) if err != nil { - log.Fatal("Failed to parse private key") + log.Fatalf("Failed to parse private key: %s", err) } sshConfig.AddHostKey(private) - app := server.NewServer(*sshConfig) + app := server.NewServer(sshConfig) app.Start() } diff --git a/server/header.go b/server/header.go index cb7602e..326d617 100644 --- a/server/header.go +++ b/server/header.go @@ -149,8 +149,6 @@ func (req *RequestHeaderFactory) Finalize() []byte { buf.Write(req.startLine) buf.WriteString("\r\n") - req.headers["X-HF"] = "modified" - for key, val := range req.headers { buf.WriteString(key) buf.WriteString(": ") diff --git a/server/http.go b/server/http.go index 0960932..8a0ba58 100644 --- a/server/http.go +++ b/server/http.go @@ -30,6 +30,7 @@ type CustomWriter struct { respMW []ResponseMiddleware reqStartMW []RequestMiddleware reqEndMW []RequestMiddleware + overflow []byte } func (cw *CustomWriter) SetInteraction(interaction Interaction) { @@ -37,9 +38,17 @@ func (cw *CustomWriter) SetInteraction(interaction Interaction) { } func (cw *CustomWriter) Read(p []byte) (int, error) { + if len(cw.overflow) > 0 { + n := copy(p, cw.overflow) + cw.overflow = cw.overflow[n:] + if len(cw.overflow) == 0 { + cw.overflow = nil + } + return n, nil + } tmp := make([]byte, len(p)) read, err := cw.reader.Read(tmp) - if err != nil { + if read == 0 && err != nil { return 0, err } @@ -48,6 +57,9 @@ func (cw *CustomWriter) Read(p []byte) (int, error) { idx := bytes.Index(tmp, DELIMITER) if idx == -1 { copy(p, tmp) + if err != nil { + return read, err + } return read, nil } @@ -74,18 +86,24 @@ func (cw *CustomWriter) Read(p []byte) (int, error) { } for _, m := range cw.reqStartMW { - err := m.HandleRequest(reqhf) - if err != nil { - log.Printf("Error when applying request middleware: %v", err) - return 0, err + if mwErr := m.HandleRequest(reqhf); mwErr != nil { + log.Printf("Error when applying request middleware: %v", mwErr) + return 0, mwErr } } cw.reqHeader = reqhf finalHeader := reqhf.Finalize() - n := copy(p, finalHeader) - n += copy(p[n:], body) + combined := append(finalHeader, body...) + + n := copy(p, combined) + + if n > len(p) { + cw.overflow = make([]byte, len(combined)-n) + copy(cw.overflow, combined[n:]) + log.Printf("output buffer too small (%d vs %d)", len(p), n) + } return n, nil } @@ -106,9 +124,7 @@ var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) func isHTTPHeader(buf []byte) bool { lines := bytes.Split(buf, []byte("\r\n")) - if len(lines) < 1 { - return false - } + startLine := string(lines[0]) if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { return false @@ -118,7 +134,8 @@ func isHTTPHeader(buf []byte) bool { if len(line) == 0 { break } - if !bytes.Contains(line, []byte(":")) { + colonIdx := bytes.IndexByte(line, ':') + if colonIdx <= 0 { return false } } @@ -130,52 +147,53 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return cw.writer.Write(p) } - cw.buf = append(cw.buf, p...) - // TODO: implement middleware buat cache system dll - if idx := bytes.Index(cw.buf, DELIMITER); idx != -1 { - header := cw.buf[:idx+len(DELIMITER)] - body := cw.buf[idx+len(DELIMITER):] - - if isHTTPHeader(header) { - resphf := NewResponseHeaderFactory(header) - for _, m := range cw.respMW { - err := m.HandleResponse(resphf, body) - if err != nil { - log.Printf("Cannot apply middleware: %s\n", err) - return 0, err - } - } - header = resphf.Finalize() - cw.respHeader = resphf - _, err := cw.writer.Write(header) - if err != nil { - return 0, err - } - - if len(body) > 0 { - _, err := cw.writer.Write(body) - if err != nil { - return 0, err - } - } - cw.buf = nil - return len(p), nil + if cw.respHeader != nil { + n, err := cw.writer.Write(p) + if err != nil { + return n, err } + return n, nil } - cw.buf = nil - n, err := cw.writer.Write(p) - if err != nil { + cw.buf = append(cw.buf, p...) + + idx := bytes.Index(cw.buf, DELIMITER) + if idx == -1 { + return len(p), nil + } + + header := cw.buf[:idx+len(DELIMITER)] + body := cw.buf[idx+len(DELIMITER):] + + if !isHTTPHeader(header) { + n, err := cw.writer.Write(cw.buf) + cw.buf = nil return n, err } + + resphf := NewResponseHeaderFactory(header) for _, m := range cw.respMW { - err := m.HandleResponse(cw.respHeader, p) + err := m.HandleResponse(resphf, body) if err != nil { log.Printf("Cannot apply middleware: %s\n", err) return 0, err } } - return n, nil + header = resphf.Finalize() + cw.respHeader = resphf + + _, err := cw.writer.Write(header) + if err != nil { + return 0, err + } + if len(body) > 0 { + _, err = cw.writer.Write(body) + if err != nil { + return 0, err + } + } + cw.buf = nil + return len(p), nil } func (cw *CustomWriter) AddInteraction(interaction Interaction) { @@ -318,9 +336,11 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS //TODO: Implement wrapper func buat add/remove middleware fingerprintMiddleware := NewTunnelFingerprint() loggerMiddleware := NewRequestLogger(cw.interaction, cw.RemoteAddr) + forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr) + cw.respMW = append(cw.respMW, fingerprintMiddleware) cw.reqStartMW = append(cw.reqStartMW, loggerMiddleware) - + cw.reqStartMW = append(cw.reqStartMW, forwardedForMiddleware) //TODO: Tambah req Middleware cw.reqEndMW = nil cw.reqHeader = initialRequest diff --git a/server/middleware.go b/server/middleware.go index a28bdab..f26504c 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -19,9 +19,7 @@ type TunnelFingerprint struct{} func NewTunnelFingerprint() *TunnelFingerprint { return &TunnelFingerprint{} } -func (h *TunnelFingerprint) HandleRequest(header *RequestHeaderFactory) error { - return nil -} + func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error { header.Set("Server", "Tunnel Please") return nil @@ -44,7 +42,22 @@ func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error { return nil } -func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil } +type ForwardedFor struct { + addr net.Addr +} + +func NewForwardedFor(addr net.Addr) *ForwardedFor { + return &ForwardedFor{addr: addr} +} + +func (ff *ForwardedFor) HandleRequest(header *RequestHeaderFactory) error { + host, _, err := net.SplitHostPort(ff.addr.String()) + if err != nil { + return err + } + header.Set("X-Forwarded-For", host) + return nil +} //TODO: Implement caching atau enggak //const maxCacheSize = 50 * 1024 * 1024 diff --git a/server/server.go b/server/server.go index 9d01817..75c9b89 100644 --- a/server/server.go +++ b/server/server.go @@ -16,7 +16,7 @@ type Server struct { HttpServer *http.Server } -func NewServer(config ssh.ServerConfig) *Server { +func NewServer(config *ssh.ServerConfig) *Server { listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("port"))) if err != nil { log.Fatalf("failed to listen on port 2200: %v", err) @@ -39,7 +39,7 @@ func NewServer(config ssh.ServerConfig) *Server { }() return &Server{ Conn: &listener, - Config: &config, + Config: config, } } diff --git a/session/handler.go b/session/handler.go index e2a77f7..db45de3 100644 --- a/session/handler.go +++ b/session/handler.go @@ -160,21 +160,29 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) - return + } + return + } + + if !registerClient(slug, s) { + log.Printf("Failed to register client with slug: %s", slug) + err := req.Reply(false, nil) + if err != nil { + log.Println("Failed to reply to request:", err) } return } s.SlugManager.Set(slug) - registerClient(slug, s) buf := new(bytes.Buffer) - err := binary.Write(buf, binary.BigEndian, uint32(80)) + err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) if err != nil { - log.Println("Failed to reply to request:", err) + log.Println("Failed to write port to buffer:", err) + unregisterClient(slug) return } - log.Printf("HTTP forwarding approved on port: %d", 80) + log.Printf("HTTP forwarding approved on port: %d", portToBind) domain := utils.Getenv("domain") protocol := "http" @@ -184,9 +192,11 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { s.Interaction.ShowWelcomeMessage() s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) + err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) + unregisterClient(slug) return } } @@ -194,7 +204,6 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { s.Forwarder.SetType(types.TCP) log.Printf("Requested forwarding on %s:%d", addr, portToBind) - listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) @@ -209,25 +218,36 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind } return } - s.Forwarder.SetListener(listener) - s.Forwarder.SetForwardedPort(portToBind) - s.Interaction.ShowWelcomeMessage() - s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) - - go s.Forwarder.AcceptTCPConnections() buf := new(bytes.Buffer) err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) if err != nil { - log.Println("Failed to reply to request:", err) + log.Println("Failed to write port to buffer:", err) + err = listener.Close() + if err != nil { + log.Printf("Failed to close listener: %s", err) + return + } return } + log.Printf("TCP forwarding approved on port: %d", portToBind) err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) + err = listener.Close() + if err != nil { + log.Printf("Failed to close listener: %s", err) + return + } return } + + s.Forwarder.SetListener(listener) + s.Forwarder.SetForwardedPort(portToBind) + s.Interaction.ShowWelcomeMessage() + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) + go s.Forwarder.AcceptTCPConnections() } func generateUniqueSlug() string { diff --git a/session/session.go b/session/session.go index e122e38..aabbe63 100644 --- a/session/session.go +++ b/session/session.go @@ -30,6 +30,8 @@ type SSHSession struct { Interaction interaction.Controller Forwarder forwarder.ForwardingController SlugManager slug.Manager + + channelOnce sync.Once } func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { @@ -61,7 +63,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan interactionManager.SetSlugModificator(updateClientSlug) forwarderManager.SetLifecycle(lifecycleManager) lifecycleManager.SetUnregisterClient(unregisterClient) - + session := &SSHSession{ Lifecycle: lifecycleManager, Interaction: interactionManager, @@ -73,20 +75,23 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan go session.Lifecycle.WaitForRunningStatus() for channel := range sshChan { - ch, reqs, _ := channel.Accept() - if session.Lifecycle.GetChannel() == nil { + ch, reqs, err := channel.Accept() + if err != nil { + log.Printf("failed to accept channel: %v", err) + continue + } + session.channelOnce.Do(func() { session.Lifecycle.SetChannel(ch) session.Interaction.SetChannel(ch) session.Lifecycle.SetStatus(types.SETUP) go session.HandleGlobalRequest(forwardingReq) - } + }) + go session.HandleGlobalRequest(reqs) } - err := session.Lifecycle.Close() - if err != nil { + if err := session.Lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) } - return }() }