From 8fb19af5a6a9ab08675c7b0d676e7e722b482b35 Mon Sep 17 00:00:00 2001 From: bagas Date: Mon, 19 Jan 2026 00:13:09 +0700 Subject: [PATCH 01/11] fix: resolve copy goroutine deadlock on early connection close - Add proper CloseWrite handling to signal EOF to other goroutine - Ensure both copy goroutines terminate when either side closes - Prevent goroutine leaks for SSH forwarded-tcpip channels: - Use select with default when sending result to resultChan - Close unused SSH channels and discard requests if main goroutine has already timed out --- server/http.go | 13 +++++++++- session/forwarder/forwarder.go | 45 +++++++++++++++++++++------------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/server/http.go b/server/http.go index e2143d5..9c00b28 100644 --- a/server/http.go +++ b/server/http.go @@ -347,7 +347,18 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi go func() { channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) - resultChan <- channelResult{channel, reqs, err} + select { + case resultChan <- channelResult{channel, reqs, err}: + default: + if channel != nil { + err := channel.Close() + if err != nil { + log.Printf("Failed to close unused channel: %v", err) + return + } + go ssh.DiscardRequests(reqs) + } + } }() var channel ssh.Channel diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 5807ac4..fcbc12f 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -73,14 +73,6 @@ func (f *forwarder) AcceptTCPConnections() { continue } - if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { - log.Printf("Failed to set connection deadline: %v", err) - if closeErr := conn.Close(); closeErr != nil { - log.Printf("Failed to close connection: %v", closeErr) - } - continue - } - payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) type channelResult struct { @@ -92,7 +84,18 @@ func (f *forwarder) AcceptTCPConnections() { go func() { channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) - resultChan <- channelResult{channel, reqs, err} + select { + case resultChan <- channelResult{channel, reqs, err}: + default: + if channel != nil { + err := channel.Close() + if err != nil { + log.Printf("Failed to close unused channel: %v", err) + return + } + go ssh.DiscardRequests(reqs) + } + } }() select { @@ -104,14 +107,8 @@ func (f *forwarder) AcceptTCPConnections() { } continue } - - if err = conn.SetDeadline(time.Time{}); err != nil { - log.Printf("Failed to clear connection deadline: %v", err) - } - go ssh.DiscardRequests(result.reqs) go f.HandleConnection(conn, result.channel, conn.RemoteAddr()) - case <-time.After(5 * time.Second): log.Printf("Timeout opening forwarded-tcpip channel") if closeErr := conn.Close(); closeErr != nil { @@ -150,7 +147,18 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA defer wg.Done() _, err := copyWithBuffer(dst, src) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying src→dst: %v", err) + log.Printf("Error copying src to dst: %v", err) + } + if conn, ok := dst.(interface{ CloseWrite() error }); ok { + if err = conn.CloseWrite(); err != nil { + log.Printf("Error closing write side of dst: %v", err) + } + } else { + if closer, closerOk := dst.(io.Closer); closerOk { + if err = closer.Close(); err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing dst connection: %v", err) + } + } } }() @@ -158,7 +166,10 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA defer wg.Done() _, err := copyWithBuffer(src, dst) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying dst→src: %v", err) + log.Printf("Error copying dst to src: %v", err) + } + if err = src.CloseWrite(); err != nil { + log.Printf("Error closing write side of src: %v", err) } }() -- 2.49.1 From adb0264bb5c8570783011ed4dab294a515bcd13a Mon Sep 17 00:00:00 2001 From: bagas Date: Mon, 19 Jan 2026 15:12:31 +0700 Subject: [PATCH 02/11] refactor(session): simplify Start() and unify forwarding logic - Extract helper functions from Start() for better code organization - Eliminate duplication with finalizeForwarding() method - Consolidate denial logic into denyForwardingRequest() - Update all handler methods to return errors instead of logging internally - Improve error handling consistency across all operations --- session/session.go | 398 ++++++++++++++++++++++----------------------- 1 file changed, 194 insertions(+), 204 deletions(-) diff --git a/session/session.go b/session/session.go index be9e9ed..4118550 100644 --- a/session/session.go +++ b/session/session.go @@ -30,10 +30,10 @@ type Detail struct { } type Session interface { - HandleGlobalRequest(ch <-chan *ssh.Request) - HandleTCPIPForward(req *ssh.Request) - HandleHTTPForward(req *ssh.Request, port uint16) - HandleTCPForward(req *ssh.Request, addr string, port uint16) + HandleGlobalRequest(ch <-chan *ssh.Request) error + HandleTCPIPForward(req *ssh.Request) error + HandleHTTPForward(req *ssh.Request, port uint16) error + HandleTCPForward(req *ssh.Request, addr string, port uint16) error Lifecycle() lifecycle.Lifecycle Interaction() interaction.Interaction Forwarder() forwarder.Forwarder @@ -88,14 +88,15 @@ func (s *session) Slug() slug.Slug { } func (s *session) Detail() *Detail { - var tunnelType string - if s.forwarder.TunnelType() == types.HTTP { - tunnelType = "HTTP" - } else if s.forwarder.TunnelType() == types.TCP { - tunnelType = "TCP" - } else { + tunnelTypeMap := map[types.TunnelType]string{ + types.HTTP: "HTTP", + types.TCP: "TCP", + } + tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()] + if !ok { tunnelType = "UNKNOWN" } + return &Detail{ ForwardingType: tunnelType, Slug: s.slug.String(), @@ -106,55 +107,80 @@ func (s *session) Detail() *Detail { } func (s *session) Start() error { - var channel ssh.NewChannel - var ok bool - select { - case channel, ok = <-s.sshChan: - if !ok { - log.Println("Forwarding request channel closed") - return nil - } - ch, reqs, err := channel.Accept() - if err != nil { - log.Printf("failed to accept channel: %v", err) - return err - } - go s.HandleGlobalRequest(reqs) - - s.lifecycle.SetChannel(ch) - s.interaction.SetChannel(ch) - s.interaction.SetMode(types.INTERACTIVE) - case <-time.After(500 * time.Millisecond): - s.interaction.SetMode(types.HEADLESS) + if err := s.setupSessionMode(); err != nil { + return err } tcpipReq := s.waitForTCPIPForward() if tcpipReq == nil { - err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) - if err != nil { - return err - } - if err = s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } - return fmt.Errorf("no forwarding Request") + return s.handleMissingForwardRequest() } - if (s.interaction.Mode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.User() == "UNAUTHORIZED" { - if err := tcpipReq.Reply(false, nil); err != nil { - log.Printf("cannot reply to tcpip req: %s\n", err) - return err - } - if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - return err - } - return nil + if s.shouldRejectUnauthorized() { + return s.denyForwardingRequest(tcpipReq, nil, nil, fmt.Sprintf("headless forwarding only allowed on node mode")) } - s.HandleTCPIPForward(tcpipReq) + if err := s.HandleTCPIPForward(tcpipReq); err != nil { + return err + } s.interaction.Start() + return s.waitForSessionEnd() +} + +func (s *session) setupSessionMode() error { + select { + case channel, ok := <-s.sshChan: + if !ok { + log.Println("Forwarding request channel closed") + return nil + } + return s.setupInteractiveMode(channel) + case <-time.After(500 * time.Millisecond): + s.interaction.SetMode(types.HEADLESS) + return nil + } +} + +func (s *session) setupInteractiveMode(channel ssh.NewChannel) error { + ch, reqs, err := channel.Accept() + if err != nil { + log.Printf("failed to accept channel: %v", err) + return err + } + + go func() { + err = s.HandleGlobalRequest(reqs) + if err != nil { + log.Printf("global request handler error: %v", err) + } + }() + + s.lifecycle.SetChannel(ch) + s.interaction.SetChannel(ch) + s.interaction.SetMode(types.INTERACTIVE) + + return nil +} + +func (s *session) handleMissingForwardRequest() error { + err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) + if err != nil { + return err + } + if err = s.lifecycle.Close(); err != nil { + log.Printf("failed to close session: %v", err) + } + return fmt.Errorf("no forwarding Request") +} + +func (s *session) shouldRejectUnauthorized() bool { + return s.interaction.Mode() == types.HEADLESS && + config.Getenv("MODE", "standalone") == "standalone" && + s.lifecycle.User() == "UNAUTHORIZED" +} + +func (s *session) waitForSessionEnd() error { if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { log.Printf("ssh connection closed with error: %v", err) } @@ -187,220 +213,184 @@ func (s *session) waitForTCPIPForward() *ssh.Request { } } -func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { +func (s *session) handleWindowChange(req *ssh.Request) error { + p := req.Payload + if len(p) < 16 { + log.Println("invalid window-change payload") + return req.Reply(false, nil) + } + + cols := binary.BigEndian.Uint32(p[0:4]) + rows := binary.BigEndian.Uint32(p[4:8]) + + s.interaction.SetWH(int(cols), int(rows)) + return req.Reply(true, nil) +} + +func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error { for req := range GlobalRequest { switch req.Type { case "shell", "pty-req": err := req.Reply(true, nil) if err != nil { - log.Println("Failed to reply to request:", err) - return + return err } case "window-change": - p := req.Payload - if len(p) < 16 { - log.Println("invalid window-change payload") - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - return - } - cols := binary.BigEndian.Uint32(p[0:4]) - rows := binary.BigEndian.Uint32(p[4:8]) - - s.interaction.SetWH(int(cols), int(rows)) - - err := req.Reply(true, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return + if err := s.handleWindowChange(req); err != nil { + return err } default: log.Println("Unknown request type:", req.Type) err := req.Reply(false, nil) if err != nil { - log.Println("Failed to reply to request:", err) - return + return err } } } + return nil } -func (s *session) HandleTCPIPForward(req *ssh.Request) { - log.Println("PortRegistry forwarding request detected") - - fail := func(msg string) { - log.Println(msg) - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - return - } - if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } - } - - reader := bytes.NewReader(req.Payload) - - addr, err := readSSHString(reader) +func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) { + address, err = readSSHString(payloadReader) if err != nil { - fail(fmt.Sprintf("Failed to read address from payload: %v", err)) - return + return "", 0, err } var rawPortToBind uint32 - if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { - fail(fmt.Sprintf("Failed to read port from payload: %v", err)) - return + if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil { + return "", 0, err } if rawPortToBind > 65535 { - fail(fmt.Sprintf("PortRegistry %d is larger than allowed port of 65535", rawPortToBind)) - return + return "", 0, fmt.Errorf("port is larger than allowed port of 65535") } - portToBind := uint16(rawPortToBind) - if isBlockedPort(portToBind) { - fail(fmt.Sprintf("PortRegistry %d is blocked or restricted", portToBind)) - return + port = uint16(rawPortToBind) + if isBlockedPort(port) { + return "", 0, fmt.Errorf("port is block") } - switch portToBind { + if port == 0 { + unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort() + if !ok { + return "", 0, fmt.Errorf("no available port") + } + return address, unassigned, err + } + + return address, port, err +} + +func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error { + var errs []error + if key != nil { + s.registry.Remove(*key) + } + if listener != nil { + if err := listener.Close(); err != nil { + errs = append(errs, fmt.Errorf("close listener: %w", err)) + } + } + if err := req.Reply(false, nil); err != nil { + errs = append(errs, fmt.Errorf("reply request: %w", err)) + } + if err := s.lifecycle.Close(); err != nil { + errs = append(errs, fmt.Errorf("close session: %w", err)) + } + errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg)) + return errors.Join(errs...) +} + +func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) { + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.BigEndian, uint32(port)) + if err != nil { + return err + } + + err = req.Reply(true, buf.Bytes()) + if err != nil { + return err + } + return nil +} + +func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error { + err := s.approveForwardingRequest(req, portToBind) + if err != nil { + return err + } + + s.forwarder.SetType(tunnelType) + s.forwarder.SetForwardedPort(portToBind) + s.slug.Set(slug) + s.lifecycle.SetStatus(types.RUNNING) + + if listener != nil { + s.forwarder.SetListener(listener) + go s.forwarder.AcceptTCPConnections() + } + + return nil +} + +func (s *session) HandleTCPIPForward(req *ssh.Request) error { + reader := bytes.NewReader(req.Payload) + + address, port, err := s.parseForwardPayload(reader) + if err != nil { + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error())) + } + + switch port { case 80, 443: - s.HandleHTTPForward(req, portToBind) + return s.HandleHTTPForward(req, port) default: - s.HandleTCPForward(req, addr, portToBind) + + return s.HandleTCPForward(req, address, port) } } -func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) { - fail := func(msg string, key *types.SessionKey) { - log.Println(msg) - if key != nil { - s.registry.Remove(*key) - } - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - } - } - +func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { randomString, err := random.GenerateRandomString(20) if err != nil { - fail(fmt.Sprintf("Failed to create slug: %s", err), nil) - return + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err)) } key := types.SessionKey{Id: randomString, Type: types.HTTP} if !s.registry.Register(key, s) { - fail(fmt.Sprintf("Failed to register client with slug: %s", randomString), nil) - return + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString)) } - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) + err = s.finalizeForwarding(req, portToBind, nil, types.HTTP, key.Id) if err != nil { - fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key) - return + return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } - log.Printf("HTTP forwarding approved on port: %d", portToBind) - - err = req.Reply(true, buf.Bytes()) - if err != nil { - fail(fmt.Sprintf("Failed to reply to request: %v", err), &key) - return - } - - s.forwarder.SetType(types.HTTP) - s.forwarder.SetForwardedPort(portToBind) - s.slug.Set(randomString) - s.lifecycle.SetStatus(types.RUNNING) + return nil } -func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - fail := func(msg string) { - log.Println(msg) - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - return - } - if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } - } - - cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) { - log.Println(msg) - if key != nil { - s.registry.Remove(*key) - } - if port != 0 { - if setErr := s.lifecycle.PortRegistry().SetPortStatus(port, false); setErr != nil { - log.Printf("Failed to reset port status: %v", setErr) - } - } - if listener != nil { - if closeErr := listener.Close(); closeErr != nil { - log.Printf("Failed to close listener: %v", closeErr) - } - } - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - } - _ = s.lifecycle.Close() - } - - if portToBind == 0 { - unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort() - if !ok { - fail("No available port") - return - } - portToBind = unassigned - } - +func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error { if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed { - fail(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) - return + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } - log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { - cleanup(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind), portToBind, nil, nil) - return + return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} - if !s.registry.Register(key, s) { - cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil) - return + return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TCP client with id: %s", key.Id)) } - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) + err = s.finalizeForwarding(req, portToBind, listener, types.TCP, key.Id) if err != nil { - cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key) - return + return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } - - log.Printf("TCP forwarding approved on port: %d", portToBind) - err = req.Reply(true, buf.Bytes()) - if err != nil { - cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key) - return - } - - s.forwarder.SetType(types.TCP) - s.forwarder.SetListener(listener) - s.forwarder.SetForwardedPort(portToBind) - s.slug.Set(key.Id) - s.lifecycle.SetStatus(types.RUNNING) - go s.forwarder.AcceptTCPConnections() + return nil } -func readSSHString(reader *bytes.Reader) (string, error) { +func readSSHString(reader io.Reader) (string, error) { var length uint32 if err := binary.Read(reader, binary.BigEndian, &length); err != nil { return "", err -- 2.49.1 From 27f49879af4b0293fa4a645a8938de74b782bab3 Mon Sep 17 00:00:00 2001 From: bagas Date: Mon, 19 Jan 2026 22:41:04 +0700 Subject: [PATCH 03/11] refactor(server): enhance HTTP handler modularity and fix resource leak - Rename customWriter struct to httpWriter for clarity - Add closeWriter field to properly close write side of connections - Update all cw variable references to hw - Merge handlerTLS into handler function to reduce code duplication - Extract handler into smaller, focused methods - Split Read/Write/forwardRequest into composable functions Fixes resource leak where connections weren't properly closed on the write side, matching the forwarder's CloseWrite() pattern. --- server/http.go | 483 ++++++++++++--------------------- server/https.go | 112 -------- server/httpwritter.go | 250 +++++++++++++++++ server/server.go | 3 +- session/forwarder/forwarder.go | 12 - 5 files changed, 428 insertions(+), 432 deletions(-) delete mode 100644 server/https.go create mode 100644 server/httpwritter.go diff --git a/server/http.go b/server/http.go index 9c00b28..f685d39 100644 --- a/server/http.go +++ b/server/http.go @@ -2,13 +2,12 @@ package server import ( "bufio" - "bytes" + "crypto/tls" "errors" "fmt" - "io" "log" "net" - "regexp" + "net/http" "strings" "time" "tunnel_pls/internal/config" @@ -18,214 +17,20 @@ import ( "golang.org/x/crypto/ssh" ) -type HTTPWriter interface { - io.Reader - io.Writer - GetRemoteAddr() net.Addr - GetWriter() io.Writer - AddResponseMiddleware(mw ResponseMiddleware) - AddRequestStartMiddleware(mw RequestMiddleware) - SetRequestHeader(header RequestHeaderManager) - GetRequestStartMiddleware() []RequestMiddleware -} - -type customWriter struct { - remoteAddr net.Addr - writer io.Writer - reader io.Reader - headerBuf []byte - buf []byte - respHeader ResponseHeaderManager - reqHeader RequestHeaderManager - respMW []ResponseMiddleware - reqStartMW []RequestMiddleware - reqEndMW []RequestMiddleware -} - -func (cw *customWriter) GetRemoteAddr() net.Addr { - return cw.remoteAddr -} - -func (cw *customWriter) GetWriter() io.Writer { - return cw.writer -} - -func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) { - cw.respMW = append(cw.respMW, mw) -} - -func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) { - cw.reqStartMW = append(cw.reqStartMW, mw) -} - -func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) { - cw.reqHeader = header -} - -func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware { - return cw.reqStartMW -} - -func (cw *customWriter) Read(p []byte) (int, error) { - tmp := make([]byte, len(p)) - read, err := cw.reader.Read(tmp) - if read == 0 && err != nil { - return 0, err - } - - tmp = tmp[:read] - - idx := bytes.Index(tmp, DELIMITER) - if idx == -1 { - copy(p, tmp) - if err != nil { - return read, err - } - return read, nil - } - - header := tmp[:idx+len(DELIMITER)] - body := tmp[idx+len(DELIMITER):] - - if !isHTTPHeader(header) { - copy(p, tmp) - return read, nil - } - - for _, m := range cw.reqEndMW { - err = m.HandleRequest(cw.reqHeader) - if err != nil { - log.Printf("Error when applying request middleware: %v", err) - return 0, err - } - } - - reqhf, err := NewRequestHeaderFactory(header) - if err != nil { - return 0, err - } - - for _, m := range cw.reqStartMW { - if mwErr := m.HandleRequest(reqhf); mwErr != nil { - log.Printf("Error when applying request middleware: %v", mwErr) - return 0, mwErr - } - } - - cw.reqHeader = reqhf - finalHeader := reqhf.Finalize() - - combined := append(finalHeader, body...) - - n := copy(p, combined) - - return n, nil -} - -func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter { - return &customWriter{ - remoteAddr: remoteAddr, - writer: writer, - reader: reader, - buf: make([]byte, 0, 4096), - } -} - -var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} -var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) -var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) - -func isHTTPHeader(buf []byte) bool { - lines := bytes.Split(buf, []byte("\r\n")) - - startLine := string(lines[0]) - if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { - return false - } - - for _, line := range lines[1:] { - if len(line) == 0 { - break - } - colonIdx := bytes.IndexByte(line, ':') - if colonIdx <= 0 { - return false - } - } - return true -} - -func (cw *customWriter) Write(p []byte) (int, error) { - if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" { - cw.respHeader = nil - } - - if cw.respHeader != nil { - n, err := cw.writer.Write(p) - if err != nil { - return n, err - } - return n, nil - } - - cw.buf = append(cw.buf, p...) - - idx := bytes.Index(cw.buf, DELIMITER) - if idx == -1 { - return len(p), nil - } - - header := cw.buf[:idx+len(DELIMITER)] - body := cw.buf[idx+len(DELIMITER):] - - if !isHTTPHeader(header) { - _, err := cw.writer.Write(cw.buf) - cw.buf = nil - if err != nil { - return 0, err - } - return len(p), nil - } - - resphf := NewResponseHeaderFactory(header) - for _, m := range cw.respMW { - err := m.HandleResponse(resphf, body) - if err != nil { - log.Printf("Cannot apply middleware: %s\n", err) - return 0, err - } - } - header = resphf.Finalize() - cw.respHeader = resphf - - _, err := cw.writer.Write(header) - if err != nil { - return 0, err - } - if len(body) > 0 { - _, err = cw.writer.Write(body) - if err != nil { - return 0, err - } - } - cw.buf = nil - return len(p), nil -} - -var redirectTLS = false - type HTTPServer interface { ListenAndServe() error ListenAndServeTLS() error - handler(conn net.Conn) - handlerTLS(conn net.Conn) } type httpServer struct { sessionRegistry session.Registry + redirectTLS bool } -func NewHTTPServer(sessionRegistry session.Registry) HTTPServer { - return &httpServer{sessionRegistry: sessionRegistry} +func NewHTTPServer(sessionRegistry session.Registry, redirectTLS bool) HTTPServer { + return &httpServer{ + sessionRegistry: sessionRegistry, + redirectTLS: redirectTLS, + } } func (hs *httpServer) ListenAndServe() error { @@ -234,9 +39,6 @@ func (hs *httpServer) ListenAndServe() error { if err != nil { return errors.New("Error listening: " + err.Error()) } - if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" { - redirectTLS = true - } go func() { for { var conn net.Conn @@ -249,21 +51,65 @@ func (hs *httpServer) ListenAndServe() error { continue } - go hs.handler(conn) + go hs.handler(conn, false) } }() return nil } -func (hs *httpServer) handler(conn net.Conn) { - defer func() { - err := conn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - log.Printf("Error closing connection: %v", err) - return +func (hs *httpServer) ListenAndServeTLS() error { + domain := config.Getenv("DOMAIN", "localhost") + httpsPort := config.Getenv("HTTPS_PORT", "8443") + + tlsConfig, err := NewTLSConfig(domain) + if err != nil { + return fmt.Errorf("failed to initialize TLS config: %w", err) + } + + ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig) + if err != nil { + return err + } + + go func() { + for { + var conn net.Conn + conn, err = ln.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + log.Println("https server closed") + } + log.Printf("Error accepting connection: %v", err) + continue + } + + go hs.handler(conn, true) } - return }() + return nil +} + +func (hs *httpServer) redirect(conn net.Conn, status int, location string) error { + _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + + fmt.Sprintf("Location: %s", location) + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "\r\n")) + if err != nil { + return err + } + return nil +} + +func (hs *httpServer) badRequest(conn net.Conn) error { + if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { + return err + } + return nil +} + +func (hs *httpServer) handler(conn net.Conn, isTLS bool) { + defer hs.closeConnection(conn) dstReader := bufio.NewReader(conn) reqhf, err := NewRequestHeaderFactory(dstReader) @@ -272,77 +118,108 @@ func (hs *httpServer) handler(conn net.Conn) { return } + slug, err := hs.extractSlug(reqhf) + if err != nil { + _ = hs.badRequest(conn) + return + } + + if hs.shouldRedirectToTLS(isTLS) { + _ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))) + return + } + + if hs.handlePingRequest(slug, conn) { + return + } + + sshSession, err := hs.getSession(slug) + if err != nil { + _ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) + return + } + + hw := NewHTTPWriter(conn, dstReader, conn.RemoteAddr()) + hs.forwardRequest(hw, reqhf, sshSession) +} + +func (hs *httpServer) closeConnection(conn net.Conn) { + err := conn.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + log.Printf("Error closing connection: %v", err) + } +} + +func (hs *httpServer) extractSlug(reqhf RequestHeaderManager) (string, error) { host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { - _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - if err != nil { - log.Println("Failed to write 400 Bad Request:", err) - return - } - return + return "", errors.New("invalid host") + } + return host[0], nil +} + +func (hs *httpServer) shouldRedirectToTLS(isTLS bool) bool { + return !isTLS && hs.redirectTLS +} + +func (hs *httpServer) handlePingRequest(slug string, conn net.Conn) bool { + if slug != "ping" { + return false } - slug := host[0] - - if redirectTLS { - _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) + + _, err := conn.Write([]byte( + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Connection: close\r\n" + - "\r\n")) - if err != nil { - log.Println("Failed to write 301 Moved Permanently:", err) - return - } - return - } - - if slug == "ping" { - _, err = conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "Access-Control-Allow-Origin: *\r\n" + - "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + - "Access-Control-Allow-Headers: *\r\n" + - "\r\n", - )) - if err != nil { - log.Println("Failed to write 200 OK:", err) - return - } - return + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + + "Access-Control-Allow-Headers: *\r\n" + + "\r\n", + )) + if err != nil { + log.Println("Failed to write 200 OK:", err) } + return true +} +func (hs *httpServer) getSession(slug string) (session.Session, error) { sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ Id: slug, Type: types.HTTP, }) if err != nil { - _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n")) - if err != nil { - log.Println("Failed to write 301 Moved Permanently:", err) - return - } - return + return nil, err } - cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - forwardRequest(cw, reqhf, sshSession) - return + return sshSession, nil } -func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) { - payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr()) +func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) { + channel, err := hs.openForwardedChannel(hw, sshSession) + if err != nil { + log.Printf("Failed to establish channel: %v", err) + sshSession.Forwarder().WriteBadGatewayResponse(hw) + return + } + + hs.setupMiddlewares(hw) + + if err := hs.sendInitialRequest(hw, initialRequest, channel); err != nil { + log.Printf("Failed to forward initial request: %v", err) + return + } + + sshSession.Forwarder().HandleConnection(hw, channel, hw.RemoteAddr()) +} + +func (hs *httpServer) openForwardedChannel(hw HTTPWriter, sshSession session.Session) (ssh.Channel, error) { + payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) type channelResult struct { channel ssh.Channel reqs <-chan *ssh.Request err error } + resultChan := make(chan channelResult, 1) go func() { @@ -350,57 +227,49 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi select { case resultChan <- channelResult{channel, reqs, err}: default: - if channel != nil { - err := channel.Close() - if err != nil { - log.Printf("Failed to close unused channel: %v", err) - return - } - go ssh.DiscardRequests(reqs) - } + hs.cleanupUnusedChannel(channel, reqs) } }() - var channel ssh.Channel - var reqs <-chan *ssh.Request - select { case result := <-resultChan: if result.err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) - sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) - return + return nil, result.err } - channel = result.channel - reqs = result.reqs + go ssh.DiscardRequests(result.reqs) + return result.channel, nil case <-time.After(5 * time.Second): - log.Printf("Timeout opening forwarded-tcpip channel") - sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) - return + return nil, errors.New("timeout opening forwarded-tcpip channel") } - - go ssh.DiscardRequests(reqs) - - fingerprintMiddleware := NewTunnelFingerprint() - forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr()) - - cw.AddResponseMiddleware(fingerprintMiddleware) - cw.AddRequestStartMiddleware(forwardedForMiddleware) - cw.SetRequestHeader(initialRequest) - - for _, m := range cw.GetRequestStartMiddleware() { - if err := m.HandleRequest(initialRequest); err != nil { - log.Printf("Error handling request: %v", err) - return - } - } - - _, err := channel.Write(initialRequest.Finalize()) - if err != nil { - log.Printf("Failed to forward request: %v", err) - return - } - - sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr()) - return +} + +func (hs *httpServer) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) { + if channel != nil { + if err := channel.Close(); err != nil { + log.Printf("Failed to close unused channel: %v", err) + } + go ssh.DiscardRequests(reqs) + } +} + +func (hs *httpServer) setupMiddlewares(hw HTTPWriter) { + fingerprintMiddleware := NewTunnelFingerprint() + forwardedForMiddleware := NewForwardedFor(hw.RemoteAddr()) + + hw.UseResponseMiddleware(fingerprintMiddleware) + hw.UseRequestMiddleware(forwardedForMiddleware) +} + +func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeaderManager, channel ssh.Channel) error { + hw.SetRequestHeader(initialRequest) + + if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { + return fmt.Errorf("error applying request middlewares: %w", err) + } + + if _, err := channel.Write(initialRequest.Finalize()); err != nil { + return fmt.Errorf("error writing to channel: %w", err) + } + + return nil } diff --git a/server/https.go b/server/https.go deleted file mode 100644 index 2758172..0000000 --- a/server/https.go +++ /dev/null @@ -1,112 +0,0 @@ -package server - -import ( - "bufio" - "crypto/tls" - "errors" - "fmt" - "log" - "net" - "strings" - "tunnel_pls/internal/config" - "tunnel_pls/types" -) - -func (hs *httpServer) ListenAndServeTLS() error { - domain := config.Getenv("DOMAIN", "localhost") - httpsPort := config.Getenv("HTTPS_PORT", "8443") - - tlsConfig, err := NewTLSConfig(domain) - if err != nil { - return fmt.Errorf("failed to initialize TLS config: %w", err) - } - - ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig) - if err != nil { - return err - } - - go func() { - for { - var conn net.Conn - conn, err = ln.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - log.Println("https server closed") - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go hs.handlerTLS(conn) - } - }() - return nil -} - -func (hs *httpServer) handlerTLS(conn net.Conn) { - defer func() { - err := conn.Close() - if err != nil { - log.Printf("Error closing connection: %v", err) - return - } - return - }() - - dstReader := bufio.NewReader(conn) - reqhf, err := NewRequestHeaderFactory(dstReader) - if err != nil { - log.Printf("Error creating request header: %v", err) - return - } - - host := strings.Split(reqhf.Get("Host"), ".") - if len(host) < 1 { - _, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - if err != nil { - log.Println("Failed to write 400 Bad Request:", err) - return - } - return - } - - slug := host[0] - - if slug == "ping" { - _, err = conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "Access-Control-Allow-Origin: *\r\n" + - "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + - "Access-Control-Allow-Headers: *\r\n" + - "\r\n", - )) - if err != nil { - log.Println("Failed to write 200 OK:", err) - return - } - return - } - - sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ - Id: slug, - Type: types.HTTP, - }) - if err != nil { - _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n")) - if err != nil { - log.Println("Failed to write 301 Moved Permanently:", err) - return - } - return - } - cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - forwardRequest(cw, reqhf, sshSession) - return -} diff --git a/server/httpwritter.go b/server/httpwritter.go new file mode 100644 index 0000000..bde7452 --- /dev/null +++ b/server/httpwritter.go @@ -0,0 +1,250 @@ +package server + +import ( + "bytes" + "io" + "log" + "net" + "regexp" +) + +type HTTPWriter interface { + io.ReadWriteCloser + CloseWrite() error + RemoteAddr() net.Addr + UseResponseMiddleware(mw ResponseMiddleware) + UseRequestMiddleware(mw RequestMiddleware) + SetRequestHeader(header RequestHeaderManager) + RequestMiddlewares() []RequestMiddleware + ResponseMiddlewares() []ResponseMiddleware + ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error + ApplyRequestMiddlewares(reqhf RequestHeaderManager) error +} + +type httpWriter struct { + remoteAddr net.Addr + writer io.Writer + reader io.Reader + headerBuf []byte + buf []byte + respHeader ResponseHeaderManager + reqHeader RequestHeaderManager + respMW []ResponseMiddleware + reqMW []RequestMiddleware +} + +var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} +var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) +var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) + +func (hw *httpWriter) RemoteAddr() net.Addr { + return hw.remoteAddr +} + +func (hw *httpWriter) UseResponseMiddleware(mw ResponseMiddleware) { + hw.respMW = append(hw.respMW, mw) +} + +func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) { + hw.reqMW = append(hw.reqMW, mw) +} + +func (hw *httpWriter) SetRequestHeader(header RequestHeaderManager) { + hw.reqHeader = header +} + +func (hw *httpWriter) RequestMiddlewares() []RequestMiddleware { + return hw.reqMW +} + +func (hw *httpWriter) ResponseMiddlewares() []ResponseMiddleware { + return hw.respMW +} +func (hw *httpWriter) Close() error { + return hw.writer.(io.Closer).Close() +} + +func (hw *httpWriter) CloseWrite() error { + if closer, ok := hw.writer.(interface{ CloseWrite() error }); ok { + return closer.CloseWrite() + } + return hw.Close() +} + +func (hw *httpWriter) Read(p []byte) (int, error) { + tmp := make([]byte, len(p)) + read, err := hw.reader.Read(tmp) + if read == 0 && err != nil { + return 0, err + } + + tmp = tmp[:read] + + headerEndIdx := bytes.Index(tmp, DELIMITER) + if headerEndIdx == -1 { + return hw.handleNoDelimiter(p, tmp, err) + } + + header, body := hw.splitHeaderAndBody(tmp, headerEndIdx) + + if !isHTTPHeader(header) { + copy(p, tmp) + return read, nil + } + + return hw.processHTTPRequest(p, header, body) +} + +func (hw *httpWriter) handleNoDelimiter(p, tmp []byte, err error) (int, error) { + copy(p, tmp) + return len(tmp), err +} + +func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) { + header := data[:delimiterIdx+len(DELIMITER)] + body := data[delimiterIdx+len(DELIMITER):] + return header, body +} + +func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { + reqhf, err := NewRequestHeaderFactory(header) + if err != nil { + return 0, err + } + + if err = hw.ApplyRequestMiddlewares(reqhf); err != nil { + return 0, err + } + + hw.reqHeader = reqhf + combined := append(reqhf.Finalize(), body...) + return copy(p, combined), nil +} + +func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeaderManager) error { + for _, m := range hw.RequestMiddlewares() { + if err := m.HandleRequest(reqhf); err != nil { + log.Printf("Error when applying request middleware: %v", err) + return err + } + } + return nil +} + +func (hw *httpWriter) Write(p []byte) (int, error) { + if hw.shouldBypassBuffering(p) { + hw.respHeader = nil + } + + if hw.respHeader != nil { + return hw.writer.Write(p) + } + + hw.buf = append(hw.buf, p...) + + headerEndIdx := bytes.Index(hw.buf, DELIMITER) + if headerEndIdx == -1 { + return len(p), nil + } + + return hw.processBufferedResponse(p, headerEndIdx) +} + +func (hw *httpWriter) shouldBypassBuffering(p []byte) bool { + return hw.respHeader != nil && len(hw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" +} + +func (hw *httpWriter) processBufferedResponse(p []byte, delimiterIdx int) (int, error) { + header, body := hw.splitHeaderAndBody(hw.buf, delimiterIdx) + + if !isHTTPHeader(header) { + return hw.writeRawBuffer() + } + + if err := hw.processHTTPResponse(header, body); err != nil { + return 0, err + } + + hw.buf = nil + return len(p), nil +} + +func (hw *httpWriter) writeRawBuffer() (int, error) { + _, err := hw.writer.Write(hw.buf) + length := len(hw.buf) + hw.buf = nil + if err != nil { + return 0, err + } + return length, nil +} + +func (hw *httpWriter) processHTTPResponse(header, body []byte) error { + resphf := NewResponseHeaderFactory(header) + + if err := hw.ApplyResponseMiddlewares(resphf, body); err != nil { + return err + } + + hw.respHeader = resphf + finalHeader := resphf.Finalize() + + if err := hw.writeHeaderAndBody(finalHeader, body); err != nil { + return err + } + + return nil +} + +func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error { + for _, m := range hw.ResponseMiddlewares() { + if err := m.HandleResponse(resphf, body); err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return err + } + } + return nil +} + +func (hw *httpWriter) writeHeaderAndBody(header, body []byte) error { + if _, err := hw.writer.Write(header); err != nil { + return err + } + + if len(body) > 0 { + if _, err := hw.writer.Write(body); err != nil { + return err + } + } + + return nil +} + +func NewHTTPWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter { + return &httpWriter{ + remoteAddr: remoteAddr, + writer: writer, + reader: reader, + buf: make([]byte, 0, 4096), + } +} + +func isHTTPHeader(buf []byte) bool { + lines := bytes.Split(buf, []byte("\r\n")) + + startLine := string(lines[0]) + if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { + return false + } + + for _, line := range lines[1:] { + if len(line) == 0 { + break + } + colonIdx := bytes.IndexByte(line, ':') + if colonIdx <= 0 { + return false + } + } + return true +} diff --git a/server/server.go b/server/server.go index 792f47e..868b9e6 100644 --- a/server/server.go +++ b/server/server.go @@ -33,8 +33,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie log.Fatalf("failed to listen on port 2200: %v", err) return nil, err } + redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" - HttpServer := NewHTTPServer(sessionRegistry) + HttpServer := NewHTTPServer(sessionRegistry, redirectTLS) err = HttpServer.ListenAndServe() if err != nil { log.Fatalf("failed to start http server: %v", err) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index fcbc12f..fa1dff4 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -124,18 +124,6 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA if err != nil { log.Printf("Failed to discard connection: %v", err) } - - err = src.Close() - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing source channel: %v", err) - } - - if closer, ok := dst.(io.Closer); ok { - err = closer.Close() - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing destination connection: %v", err) - } - } }() log.Printf("Handling new forwarded connection from %s", remoteAddr) -- 2.49.1 From aa1a46517831f10b84de517362fa85ee1b754cbe Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 20 Jan 2026 19:01:15 +0700 Subject: [PATCH 04/11] refactor(forwarder): improve connection handling and cleanup - Extract copyAndClose method for bidirectional data transfe - Add closeWriter helper for graceful connection shutdown - Add handleIncomingConnection helper - Add openForwardedChannel helper --- session/forwarder/forwarder.go | 144 +++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 62 deletions(-) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index fa1dff4..cac6691 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" "log" "net" @@ -62,6 +63,55 @@ type Forwarder interface { Close() error } +func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { + type channelResult struct { + channel ssh.Channel + reqs <-chan *ssh.Request + err error + } + resultChan := make(chan channelResult, 1) + + go func() { + channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) + select { + case resultChan <- channelResult{channel, reqs, err}: + default: + if channel != nil { + err = channel.Close() + if err != nil { + log.Printf("Failed to close unused channel: %v", err) + return + } + go ssh.DiscardRequests(reqs) + } + } + }() + + select { + case result := <-resultChan: + return result.channel, result.reqs, result.err + case <-time.After(5 * time.Second): + return nil, nil, errors.New("timeout opening forwarded-tcpip channel") + } +} + +func (f *forwarder) handleIncomingConnection(conn net.Conn) { + payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) + + channel, reqs, err := f.openForwardedChannel(payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + err = conn.Close() + if err != nil { + log.Printf("Failed to close connection: %v", err) + } + return + } + + go ssh.DiscardRequests(reqs) + go f.HandleConnection(conn, channel, conn.RemoteAddr()) +} + func (f *forwarder) AcceptTCPConnections() { for { conn, err := f.Listener().Accept() @@ -73,51 +123,33 @@ func (f *forwarder) AcceptTCPConnections() { continue } - payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) - - type channelResult struct { - channel ssh.Channel - reqs <-chan *ssh.Request - err error - } - resultChan := make(chan channelResult, 1) - - go func() { - channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) - select { - case resultChan <- channelResult{channel, reqs, err}: - default: - if channel != nil { - err := channel.Close() - if err != nil { - log.Printf("Failed to close unused channel: %v", err) - return - } - go ssh.DiscardRequests(reqs) - } - } - }() - - select { - case result := <-resultChan: - if result.err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) - if closeErr := conn.Close(); closeErr != nil { - log.Printf("Failed to close connection: %v", closeErr) - } - continue - } - go ssh.DiscardRequests(result.reqs) - go f.HandleConnection(conn, result.channel, conn.RemoteAddr()) - case <-time.After(5 * time.Second): - log.Printf("Timeout opening forwarded-tcpip channel") - if closeErr := conn.Close(); closeErr != nil { - log.Printf("Failed to close connection: %v", closeErr) - } - } + go f.handleIncomingConnection(conn) } } +func closeWriter(w io.Writer) error { + if cw, ok := w.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + if closer, ok := w.(io.Closer); ok { + return closer.Close() + } + return nil +} + +func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error { + var errs []error + _, err := copyWithBuffer(dst, src) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err)) + } + + if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) { + errs = append(errs, fmt.Errorf("close writer error (%s): %w", direction, err)) + } + return errors.Join(errs...) +} + func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { defer func() { _, err := io.Copy(io.Discard, src) @@ -133,31 +165,19 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA go func() { defer wg.Done() - _, err := copyWithBuffer(dst, src) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying src to dst: %v", err) - } - if conn, ok := dst.(interface{ CloseWrite() error }); ok { - if err = conn.CloseWrite(); err != nil { - log.Printf("Error closing write side of dst: %v", err) - } - } else { - if closer, closerOk := dst.(io.Closer); closerOk { - if err = closer.Close(); err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing dst connection: %v", err) - } - } + err := f.copyAndClose(dst, src, "src to dst") + if err != nil { + log.Println("Error during copy: ", err) + return } }() go func() { defer wg.Done() - _, err := copyWithBuffer(src, dst) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying dst to src: %v", err) - } - if err = src.CloseWrite(); err != nil { - log.Printf("Error closing write side of src: %v", err) + err := f.copyAndClose(src, dst, "dst to src") + if err != nil { + log.Println("Error during copy: ", err) + return } }() -- 2.49.1 From e3ead4d52fdd37abcf2d1834171d08d016cda9d2 Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 20 Jan 2026 19:07:47 +0700 Subject: [PATCH 05/11] refactor: optimize header parsing and remove factory naming - Remove factory naming - Use direct byte indexing instead of bytes.TrimRight - Extract parseStartLine and setRemainingHeaders helpers --- server/header.go | 244 +++++++++++++++++++----------------------- server/http.go | 10 +- server/httpwritter.go | 27 ++--- server/middleware.go | 8 +- 4 files changed, 135 insertions(+), 154 deletions(-) diff --git a/server/header.go b/server/header.go index 584394b..bc3ce73 100644 --- a/server/header.go +++ b/server/header.go @@ -6,22 +6,20 @@ import ( "fmt" ) -type HeaderManager interface { - Get(key string) []byte - Set(key string, value []byte) - Remove(key string) - Finalize() []byte -} - -type ResponseHeaderManager interface { - Get(key string) string +type ResponseHeader interface { + Value(key string) string Set(key string, value string) Remove(key string) Finalize() []byte } -type RequestHeaderManager interface { - Get(key string) string +type responseHeader struct { + startLine []byte + headers map[string]string +} + +type RequestHeader interface { + Value(key string) string Set(key string, value string) Remove(key string) Finalize() []byte @@ -29,13 +27,7 @@ type RequestHeaderManager interface { GetPath() string GetVersion() string } - -type responseHeaderFactory struct { - startLine []byte - headers map[string]string -} - -type requestHeaderFactory struct { +type requestHeader struct { method string path string version string @@ -43,7 +35,7 @@ type requestHeaderFactory struct { headers map[string]string } -func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) { +func NewRequestHeader(r interface{}) (RequestHeader, error) { switch v := r.(type) { case []byte: return parseHeadersFromBytes(v) @@ -54,38 +46,16 @@ func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) { } } -func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) { - header := &requestHeaderFactory{ - headers: make(map[string]string, 16), - } - - lineEnd := bytes.IndexByte(headerData, '\n') - if lineEnd == -1 { - return nil, fmt.Errorf("invalid request: no newline found") - } - - startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n") - header.startLine = make([]byte, len(startLine)) - copy(header.startLine, startLine) - - parts := bytes.Split(startLine, []byte{' '}) - if len(parts) < 3 { - return nil, fmt.Errorf("invalid request line") - } - - header.method = string(parts[0]) - header.path = string(parts[1]) - header.version = string(parts[2]) - - remaining := headerData[lineEnd+1:] - +func setRemainingHeaders(remaining []byte, header interface { + Set(key string, value string) +}) { for len(remaining) > 0 { - lineEnd = bytes.IndexByte(remaining, '\n') + lineEnd := bytes.Index(remaining, []byte("\r\n")) if lineEnd == -1 { lineEnd = len(remaining) } - line := bytes.TrimRight(remaining[:lineEnd], "\r\n") + line := remaining[:lineEnd] if len(line) == 0 { break @@ -95,63 +65,84 @@ func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) { if colonIdx != -1 { key := bytes.TrimSpace(line[:colonIdx]) value := bytes.TrimSpace(line[colonIdx+1:]) - header.headers[string(key)] = string(value) + header.Set(string(key), string(value)) } if lineEnd == len(remaining) { break } - remaining = remaining[lineEnd+1:] + + remaining = remaining[lineEnd+2:] } +} + +func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) { + header := &requestHeader{ + headers: make(map[string]string, 16), + } + + lineEnd := bytes.Index(headerData, []byte("\r\n")) + if lineEnd == -1 { + return nil, fmt.Errorf("invalid request: no CRLF found in start line") + } + + startLine := headerData[:lineEnd] + header.startLine = startLine + var err error + header.method, header.path, header.version, err = parseStartLine(startLine) + if err != nil { + return nil, err + } + + remaining := headerData[lineEnd+2:] + + setRemainingHeaders(remaining, header) return header, nil } -func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) { - header := &requestHeaderFactory{ +func parseStartLine(startLine []byte) (method, path, version string, err error) { + firstSpace := bytes.IndexByte(startLine, ' ') + if firstSpace == -1 { + return "", "", "", fmt.Errorf("invalid start line: missing method") + } + + secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ') + if secondSpace == -1 { + return "", "", "", fmt.Errorf("invalid start line: missing version") + } + secondSpace += firstSpace + 1 + + method = string(startLine[:firstSpace]) + path = string(startLine[firstSpace+1 : secondSpace]) + version = string(startLine[secondSpace+1:]) + + return method, path, version, nil +} + +func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) { + header := &requestHeader{ headers: make(map[string]string, 16), } startLineBytes, err := br.ReadSlice('\n') if err != nil { - if err == bufio.ErrBufferFull { - var startLine string - startLine, err = br.ReadString('\n') - if err != nil { - return nil, err - } - startLineBytes = []byte(startLine) - } else { - return nil, err - } + return nil, err } startLineBytes = bytes.TrimRight(startLineBytes, "\r\n") header.startLine = make([]byte, len(startLineBytes)) copy(header.startLine, startLineBytes) - parts := bytes.Split(startLineBytes, []byte{' '}) - if len(parts) < 3 { - return nil, fmt.Errorf("invalid request line") + header.method, header.path, header.version, err = parseStartLine(header.startLine) + if err != nil { + return nil, err } - header.method = string(parts[0]) - header.path = string(parts[1]) - header.version = string(parts[2]) - for { lineBytes, err := br.ReadSlice('\n') if err != nil { - if err == bufio.ErrBufferFull { - var line string - line, err = br.ReadString('\n') - if err != nil { - return nil, err - } - lineBytes = []byte(line) - } else { - return nil, err - } + return nil, err } lineBytes = bytes.TrimRight(lineBytes, "\r\n") @@ -174,63 +165,63 @@ func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) { return header, nil } -func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager { - header := &responseHeaderFactory{ +func NewResponseHeader(headerData []byte) (ResponseHeader, error) { + header := &responseHeader{ startLine: nil, - headers: make(map[string]string), + headers: make(map[string]string, 16), } - lines := bytes.Split(startLine, []byte("\r\n")) - if len(lines) == 0 { - return header - } - header.startLine = lines[0] - for _, h := range lines[1:] { - if len(h) == 0 { - continue - } - parts := bytes.SplitN(h, []byte(":"), 2) - if len(parts) < 2 { - continue - } - - key := parts[0] - val := bytes.TrimSpace(parts[1]) - header.headers[string(key)] = string(val) + lineEnd := bytes.Index(headerData, []byte("\r\n")) + if lineEnd == -1 { + return nil, fmt.Errorf("invalid request: no CRLF found in start line") } - return header + + header.startLine = headerData[:lineEnd] + remaining := headerData[lineEnd+2:] + setRemainingHeaders(remaining, header) + + return header, nil } -func (resp *responseHeaderFactory) Get(key string) string { +func (resp *responseHeader) Value(key string) string { return resp.headers[key] } -func (resp *responseHeaderFactory) Set(key string, value string) { +func (resp *responseHeader) Set(key string, value string) { resp.headers[key] = value } -func (resp *responseHeaderFactory) Remove(key string) { +func (resp *responseHeader) Remove(key string) { delete(resp.headers, key) } -func (resp *responseHeaderFactory) Finalize() []byte { - var buf bytes.Buffer +func finalize(startLine []byte, headers map[string]string) []byte { + size := len(startLine) + 2 + for key, val := range headers { + size += len(key) + 2 + len(val) + 2 + } + size += 2 - buf.Write(resp.startLine) - buf.WriteString("\r\n") + buf := make([]byte, 0, size) + buf = append(buf, startLine...) + buf = append(buf, '\r', '\n') - for key, val := range resp.headers { - buf.WriteString(key) - buf.WriteString(": ") - buf.WriteString(val) - buf.WriteString("\r\n") + for key, val := range headers { + buf = append(buf, key...) + buf = append(buf, ':', ' ') + buf = append(buf, val...) + buf = append(buf, '\r', '\n') } - buf.WriteString("\r\n") - return buf.Bytes() + buf = append(buf, '\r', '\n') + return buf } -func (req *requestHeaderFactory) Get(key string) string { +func (resp *responseHeader) Finalize() []byte { + return finalize(resp.startLine, resp.headers) +} + +func (req *requestHeader) Value(key string) string { val, ok := req.headers[key] if !ok { return "" @@ -238,39 +229,26 @@ func (req *requestHeaderFactory) Get(key string) string { return val } -func (req *requestHeaderFactory) Set(key string, value string) { +func (req *requestHeader) Set(key string, value string) { req.headers[key] = value } -func (req *requestHeaderFactory) Remove(key string) { +func (req *requestHeader) Remove(key string) { delete(req.headers, key) } -func (req *requestHeaderFactory) GetMethod() string { +func (req *requestHeader) GetMethod() string { return req.method } -func (req *requestHeaderFactory) GetPath() string { +func (req *requestHeader) GetPath() string { return req.path } -func (req *requestHeaderFactory) GetVersion() string { +func (req *requestHeader) GetVersion() string { return req.version } -func (req *requestHeaderFactory) Finalize() []byte { - var buf bytes.Buffer - - buf.Write(req.startLine) - buf.WriteString("\r\n") - - for key, val := range req.headers { - buf.WriteString(key) - buf.WriteString(": ") - buf.WriteString(val) - buf.WriteString("\r\n") - } - - buf.WriteString("\r\n") - return buf.Bytes() +func (req *requestHeader) Finalize() []byte { + return finalize(req.startLine, req.headers) } diff --git a/server/http.go b/server/http.go index f685d39..e8da8a6 100644 --- a/server/http.go +++ b/server/http.go @@ -112,7 +112,7 @@ func (hs *httpServer) handler(conn net.Conn, isTLS bool) { defer hs.closeConnection(conn) dstReader := bufio.NewReader(conn) - reqhf, err := NewRequestHeaderFactory(dstReader) + reqhf, err := NewRequestHeader(dstReader) if err != nil { log.Printf("Error creating request header: %v", err) return @@ -150,8 +150,8 @@ func (hs *httpServer) closeConnection(conn net.Conn) { } } -func (hs *httpServer) extractSlug(reqhf RequestHeaderManager) (string, error) { - host := strings.Split(reqhf.Get("Host"), ".") +func (hs *httpServer) extractSlug(reqhf RequestHeader) (string, error) { + host := strings.Split(reqhf.Value("Host"), ".") if len(host) < 1 { return "", errors.New("invalid host") } @@ -193,7 +193,7 @@ func (hs *httpServer) getSession(slug string) (session.Session, error) { return sshSession, nil } -func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) { +func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeader, sshSession session.Session) { channel, err := hs.openForwardedChannel(hw, sshSession) if err != nil { log.Printf("Failed to establish channel: %v", err) @@ -260,7 +260,7 @@ func (hs *httpServer) setupMiddlewares(hw HTTPWriter) { hw.UseRequestMiddleware(forwardedForMiddleware) } -func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeaderManager, channel ssh.Channel) error { +func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeader, channel ssh.Channel) error { hw.SetRequestHeader(initialRequest) if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { diff --git a/server/httpwritter.go b/server/httpwritter.go index bde7452..64154d0 100644 --- a/server/httpwritter.go +++ b/server/httpwritter.go @@ -14,11 +14,11 @@ type HTTPWriter interface { RemoteAddr() net.Addr UseResponseMiddleware(mw ResponseMiddleware) UseRequestMiddleware(mw RequestMiddleware) - SetRequestHeader(header RequestHeaderManager) + SetRequestHeader(header RequestHeader) RequestMiddlewares() []RequestMiddleware ResponseMiddlewares() []ResponseMiddleware - ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error - ApplyRequestMiddlewares(reqhf RequestHeaderManager) error + ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error + ApplyRequestMiddlewares(reqhf RequestHeader) error } type httpWriter struct { @@ -27,8 +27,8 @@ type httpWriter struct { reader io.Reader headerBuf []byte buf []byte - respHeader ResponseHeaderManager - reqHeader RequestHeaderManager + respHeader ResponseHeader + reqHeader RequestHeader respMW []ResponseMiddleware reqMW []RequestMiddleware } @@ -49,7 +49,7 @@ func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) { hw.reqMW = append(hw.reqMW, mw) } -func (hw *httpWriter) SetRequestHeader(header RequestHeaderManager) { +func (hw *httpWriter) SetRequestHeader(header RequestHeader) { hw.reqHeader = header } @@ -107,7 +107,7 @@ func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, } func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { - reqhf, err := NewRequestHeaderFactory(header) + reqhf, err := NewRequestHeader(header) if err != nil { return 0, err } @@ -121,7 +121,7 @@ func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { return copy(p, combined), nil } -func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeaderManager) error { +func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeader) error { for _, m := range hw.RequestMiddlewares() { if err := m.HandleRequest(reqhf); err != nil { log.Printf("Error when applying request middleware: %v", err) @@ -180,23 +180,26 @@ func (hw *httpWriter) writeRawBuffer() (int, error) { } func (hw *httpWriter) processHTTPResponse(header, body []byte) error { - resphf := NewResponseHeaderFactory(header) + resphf, err := NewResponseHeader(header) + if err != nil { + return err + } - if err := hw.ApplyResponseMiddlewares(resphf, body); err != nil { + if err = hw.ApplyResponseMiddlewares(resphf, body); err != nil { return err } hw.respHeader = resphf finalHeader := resphf.Finalize() - if err := hw.writeHeaderAndBody(finalHeader, body); err != nil { + if err = hw.writeHeaderAndBody(finalHeader, body); err != nil { return err } return nil } -func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error { +func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error { for _, m := range hw.ResponseMiddlewares() { if err := m.HandleResponse(resphf, body); err != nil { log.Printf("Cannot apply middleware: %s\n", err) diff --git a/server/middleware.go b/server/middleware.go index ee6ca1a..63b2467 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -5,11 +5,11 @@ import ( ) type RequestMiddleware interface { - HandleRequest(header RequestHeaderManager) error + HandleRequest(header RequestHeader) error } type ResponseMiddleware interface { - HandleResponse(header ResponseHeaderManager, body []byte) error + HandleResponse(header ResponseHeader, body []byte) error } type TunnelFingerprint struct{} @@ -18,7 +18,7 @@ func NewTunnelFingerprint() *TunnelFingerprint { return &TunnelFingerprint{} } -func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error { +func (h *TunnelFingerprint) HandleResponse(header ResponseHeader, body []byte) error { header.Set("Server", "Tunnel Please") return nil } @@ -31,7 +31,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor { return &ForwardedFor{addr: addr} } -func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error { +func (ff *ForwardedFor) HandleRequest(header RequestHeader) error { host, _, err := net.SplitHostPort(ff.addr.String()) if err != nil { return err -- 2.49.1 From 9a4539cc02d42eb408339454059d57e66bf7ed54 Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 20 Jan 2026 21:15:34 +0700 Subject: [PATCH 06/11] refactor(httpheader): extract header parsing into dedicated package Moved HTTP header parsing and building logic from server package to internal/httpheader --- internal/httpheader/header.go | 30 +++++ .../httpheader/parser.go | 108 +----------------- internal/httpheader/request.go | 49 ++++++++ internal/httpheader/response.go | 40 +++++++ server/http.go | 9 +- server/httpwritter.go | 21 ++-- server/middleware.go | 9 +- 7 files changed, 141 insertions(+), 125 deletions(-) create mode 100644 internal/httpheader/header.go rename server/header.go => internal/httpheader/parser.go (59%) create mode 100644 internal/httpheader/request.go create mode 100644 internal/httpheader/response.go diff --git a/internal/httpheader/header.go b/internal/httpheader/header.go new file mode 100644 index 0000000..ccd1bed --- /dev/null +++ b/internal/httpheader/header.go @@ -0,0 +1,30 @@ +package httpheader + +type ResponseHeader interface { + Value(key string) string + Set(key string, value string) + Remove(key string) + Finalize() []byte +} + +type responseHeader struct { + startLine []byte + headers map[string]string +} + +type RequestHeader interface { + Value(key string) string + Set(key string, value string) + Remove(key string) + Finalize() []byte + GetMethod() string + GetPath() string + GetVersion() string +} +type requestHeader struct { + method string + path string + version string + startLine []byte + headers map[string]string +} diff --git a/server/header.go b/internal/httpheader/parser.go similarity index 59% rename from server/header.go rename to internal/httpheader/parser.go index bc3ce73..3325ae5 100644 --- a/server/header.go +++ b/internal/httpheader/parser.go @@ -1,4 +1,4 @@ -package server +package httpheader import ( "bufio" @@ -6,46 +6,6 @@ import ( "fmt" ) -type ResponseHeader interface { - Value(key string) string - Set(key string, value string) - Remove(key string) - Finalize() []byte -} - -type responseHeader struct { - startLine []byte - headers map[string]string -} - -type RequestHeader interface { - Value(key string) string - Set(key string, value string) - Remove(key string) - Finalize() []byte - GetMethod() string - GetPath() string - GetVersion() string -} -type requestHeader struct { - method string - path string - version string - startLine []byte - headers map[string]string -} - -func NewRequestHeader(r interface{}) (RequestHeader, error) { - switch v := r.(type) { - case []byte: - return parseHeadersFromBytes(v) - case *bufio.Reader: - return parseHeadersFromReader(v) - default: - return nil, fmt.Errorf("unsupported type: %T", r) - } -} - func setRemainingHeaders(remaining []byte, header interface { Set(key string, value string) }) { @@ -165,36 +125,6 @@ func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) { return header, nil } -func NewResponseHeader(headerData []byte) (ResponseHeader, error) { - header := &responseHeader{ - startLine: nil, - headers: make(map[string]string, 16), - } - - lineEnd := bytes.Index(headerData, []byte("\r\n")) - if lineEnd == -1 { - return nil, fmt.Errorf("invalid request: no CRLF found in start line") - } - - header.startLine = headerData[:lineEnd] - remaining := headerData[lineEnd+2:] - setRemainingHeaders(remaining, header) - - return header, nil -} - -func (resp *responseHeader) Value(key string) string { - return resp.headers[key] -} - -func (resp *responseHeader) Set(key string, value string) { - resp.headers[key] = value -} - -func (resp *responseHeader) Remove(key string) { - delete(resp.headers, key) -} - func finalize(startLine []byte, headers map[string]string) []byte { size := len(startLine) + 2 for key, val := range headers { @@ -216,39 +146,3 @@ func finalize(startLine []byte, headers map[string]string) []byte { buf = append(buf, '\r', '\n') return buf } - -func (resp *responseHeader) Finalize() []byte { - return finalize(resp.startLine, resp.headers) -} - -func (req *requestHeader) Value(key string) string { - val, ok := req.headers[key] - if !ok { - return "" - } - return val -} - -func (req *requestHeader) Set(key string, value string) { - req.headers[key] = value -} - -func (req *requestHeader) Remove(key string) { - delete(req.headers, key) -} - -func (req *requestHeader) GetMethod() string { - return req.method -} - -func (req *requestHeader) GetPath() string { - return req.path -} - -func (req *requestHeader) GetVersion() string { - return req.version -} - -func (req *requestHeader) Finalize() []byte { - return finalize(req.startLine, req.headers) -} diff --git a/internal/httpheader/request.go b/internal/httpheader/request.go new file mode 100644 index 0000000..ae63340 --- /dev/null +++ b/internal/httpheader/request.go @@ -0,0 +1,49 @@ +package httpheader + +import ( + "bufio" + "fmt" +) + +func NewRequestHeader(r interface{}) (RequestHeader, error) { + switch v := r.(type) { + case []byte: + return parseHeadersFromBytes(v) + case *bufio.Reader: + return parseHeadersFromReader(v) + default: + return nil, fmt.Errorf("unsupported type: %T", r) + } +} + +func (req *requestHeader) Value(key string) string { + val, ok := req.headers[key] + if !ok { + return "" + } + return val +} + +func (req *requestHeader) Set(key string, value string) { + req.headers[key] = value +} + +func (req *requestHeader) Remove(key string) { + delete(req.headers, key) +} + +func (req *requestHeader) GetMethod() string { + return req.method +} + +func (req *requestHeader) GetPath() string { + return req.path +} + +func (req *requestHeader) GetVersion() string { + return req.version +} + +func (req *requestHeader) Finalize() []byte { + return finalize(req.startLine, req.headers) +} diff --git a/internal/httpheader/response.go b/internal/httpheader/response.go new file mode 100644 index 0000000..63ad352 --- /dev/null +++ b/internal/httpheader/response.go @@ -0,0 +1,40 @@ +package httpheader + +import ( + "bytes" + "fmt" +) + +func NewResponseHeader(headerData []byte) (ResponseHeader, error) { + header := &responseHeader{ + startLine: nil, + headers: make(map[string]string, 16), + } + + lineEnd := bytes.Index(headerData, []byte("\r\n")) + if lineEnd == -1 { + return nil, fmt.Errorf("invalid response: no CRLF found in start line") + } + + header.startLine = headerData[:lineEnd] + remaining := headerData[lineEnd+2:] + setRemainingHeaders(remaining, header) + + return header, nil +} + +func (resp *responseHeader) Value(key string) string { + return resp.headers[key] +} + +func (resp *responseHeader) Set(key string, value string) { + resp.headers[key] = value +} + +func (resp *responseHeader) Remove(key string) { + delete(resp.headers, key) +} + +func (resp *responseHeader) Finalize() []byte { + return finalize(resp.startLine, resp.headers) +} diff --git a/server/http.go b/server/http.go index e8da8a6..a36b6b1 100644 --- a/server/http.go +++ b/server/http.go @@ -11,6 +11,7 @@ import ( "strings" "time" "tunnel_pls/internal/config" + "tunnel_pls/internal/httpheader" "tunnel_pls/session" "tunnel_pls/types" @@ -112,7 +113,7 @@ func (hs *httpServer) handler(conn net.Conn, isTLS bool) { defer hs.closeConnection(conn) dstReader := bufio.NewReader(conn) - reqhf, err := NewRequestHeader(dstReader) + reqhf, err := httpheader.NewRequestHeader(dstReader) if err != nil { log.Printf("Error creating request header: %v", err) return @@ -150,7 +151,7 @@ func (hs *httpServer) closeConnection(conn net.Conn) { } } -func (hs *httpServer) extractSlug(reqhf RequestHeader) (string, error) { +func (hs *httpServer) extractSlug(reqhf httpheader.RequestHeader) (string, error) { host := strings.Split(reqhf.Value("Host"), ".") if len(host) < 1 { return "", errors.New("invalid host") @@ -193,7 +194,7 @@ func (hs *httpServer) getSession(slug string) (session.Session, error) { return sshSession, nil } -func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeader, sshSession session.Session) { +func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, sshSession session.Session) { channel, err := hs.openForwardedChannel(hw, sshSession) if err != nil { log.Printf("Failed to establish channel: %v", err) @@ -260,7 +261,7 @@ func (hs *httpServer) setupMiddlewares(hw HTTPWriter) { hw.UseRequestMiddleware(forwardedForMiddleware) } -func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeader, channel ssh.Channel) error { +func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, channel ssh.Channel) error { hw.SetRequestHeader(initialRequest) if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { diff --git a/server/httpwritter.go b/server/httpwritter.go index 64154d0..9d52f24 100644 --- a/server/httpwritter.go +++ b/server/httpwritter.go @@ -6,6 +6,7 @@ import ( "log" "net" "regexp" + "tunnel_pls/internal/httpheader" ) type HTTPWriter interface { @@ -14,11 +15,11 @@ type HTTPWriter interface { RemoteAddr() net.Addr UseResponseMiddleware(mw ResponseMiddleware) UseRequestMiddleware(mw RequestMiddleware) - SetRequestHeader(header RequestHeader) + SetRequestHeader(header httpheader.RequestHeader) RequestMiddlewares() []RequestMiddleware ResponseMiddlewares() []ResponseMiddleware - ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error - ApplyRequestMiddlewares(reqhf RequestHeader) error + ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error + ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error } type httpWriter struct { @@ -27,8 +28,8 @@ type httpWriter struct { reader io.Reader headerBuf []byte buf []byte - respHeader ResponseHeader - reqHeader RequestHeader + respHeader httpheader.ResponseHeader + reqHeader httpheader.RequestHeader respMW []ResponseMiddleware reqMW []RequestMiddleware } @@ -49,7 +50,7 @@ func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) { hw.reqMW = append(hw.reqMW, mw) } -func (hw *httpWriter) SetRequestHeader(header RequestHeader) { +func (hw *httpWriter) SetRequestHeader(header httpheader.RequestHeader) { hw.reqHeader = header } @@ -107,7 +108,7 @@ func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, } func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { - reqhf, err := NewRequestHeader(header) + reqhf, err := httpheader.NewRequestHeader(header) if err != nil { return 0, err } @@ -121,7 +122,7 @@ func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { return copy(p, combined), nil } -func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeader) error { +func (hw *httpWriter) ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error { for _, m := range hw.RequestMiddlewares() { if err := m.HandleRequest(reqhf); err != nil { log.Printf("Error when applying request middleware: %v", err) @@ -180,7 +181,7 @@ func (hw *httpWriter) writeRawBuffer() (int, error) { } func (hw *httpWriter) processHTTPResponse(header, body []byte) error { - resphf, err := NewResponseHeader(header) + resphf, err := httpheader.NewResponseHeader(header) if err != nil { return err } @@ -199,7 +200,7 @@ func (hw *httpWriter) processHTTPResponse(header, body []byte) error { return nil } -func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error { +func (hw *httpWriter) ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error { for _, m := range hw.ResponseMiddlewares() { if err := m.HandleResponse(resphf, body); err != nil { log.Printf("Cannot apply middleware: %s\n", err) diff --git a/server/middleware.go b/server/middleware.go index 63b2467..6f50c4c 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -2,14 +2,15 @@ package server import ( "net" + "tunnel_pls/internal/httpheader" ) type RequestMiddleware interface { - HandleRequest(header RequestHeader) error + HandleRequest(header httpheader.RequestHeader) error } type ResponseMiddleware interface { - HandleResponse(header ResponseHeader, body []byte) error + HandleResponse(header httpheader.ResponseHeader, body []byte) error } type TunnelFingerprint struct{} @@ -18,7 +19,7 @@ func NewTunnelFingerprint() *TunnelFingerprint { return &TunnelFingerprint{} } -func (h *TunnelFingerprint) HandleResponse(header ResponseHeader, body []byte) error { +func (h *TunnelFingerprint) HandleResponse(header httpheader.ResponseHeader, body []byte) error { header.Set("Server", "Tunnel Please") return nil } @@ -31,7 +32,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor { return &ForwardedFor{addr: addr} } -func (ff *ForwardedFor) HandleRequest(header RequestHeader) error { +func (ff *ForwardedFor) HandleRequest(header httpheader.RequestHeader) error { host, _, err := net.SplitHostPort(ff.addr.String()) if err != nil { return err -- 2.49.1 From 1e12373359384c88df29c1f49f00ed6e482f231c Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 21 Jan 2026 14:06:46 +0700 Subject: [PATCH 07/11] chore(restructure): reorganize project layout - Reorganize internal packages and overall project structure - Update imports and wiring to match the new layout - Separate HTTP parsing and streaming from the server package - Separate middleware from the server package - Separate session registry from the session package - Move HTTP, HTTPS, and TCP servers to the transport package - Session package no longer starts the TCP server directly - Server package no longer starts HTTP/HTTPS servers on initialization - Forwarder no longer handles accepting TCP requests - Move session details to the types package - HTTP/HTTPS initialization is now the responsibility of main --- internal/grpc/client/client.go | 7 +- .../{httpheader => http/header}/header.go | 2 +- .../{httpheader => http/header}/parser.go | 2 +- .../{httpheader => http/header}/request.go | 4 +- .../{httpheader => http/header}/response.go | 4 +- internal/http/stream/parser.go | 29 ++ internal/http/stream/reader.go | 50 ++++ internal/http/stream/stream.go | 103 +++++++ internal/http/stream/writer.go | 88 ++++++ internal/middleware/forwardedfor.go | 23 ++ internal/middleware/middleware.go | 13 + internal/middleware/tunnelfingerprint.go | 16 + internal/port/port.go | 38 +-- {session => internal/registry}/registry.go | 28 +- internal/transport/http.go | 40 +++ internal/transport/httphandler.go | 227 ++++++++++++++ internal/transport/https.go | 48 +++ internal/transport/tcp.go | 66 +++++ {server => internal/transport}/tls.go | 2 +- internal/transport/transport.go | 10 + {version => internal/version}/version.go | 0 main.go | 54 +++- server/http.go | 276 ------------------ server/httpwritter.go | 254 ---------------- server/middleware.go | 42 --- server/server.go | 39 +-- session/forwarder/forwarder.go | 44 +-- session/lifecycle/lifecycle.go | 10 +- session/session.go | 37 +-- types/types.go | 10 + 30 files changed, 862 insertions(+), 704 deletions(-) rename internal/{httpheader => http/header}/header.go (96%) rename internal/{httpheader => http/header}/parser.go (99%) rename internal/{httpheader => http/header}/request.go (90%) rename internal/{httpheader => http/header}/response.go (89%) create mode 100644 internal/http/stream/parser.go create mode 100644 internal/http/stream/reader.go create mode 100644 internal/http/stream/stream.go create mode 100644 internal/http/stream/writer.go create mode 100644 internal/middleware/forwardedfor.go create mode 100644 internal/middleware/middleware.go create mode 100644 internal/middleware/tunnelfingerprint.go rename {session => internal/registry}/registry.go (90%) create mode 100644 internal/transport/http.go create mode 100644 internal/transport/httphandler.go create mode 100644 internal/transport/https.go create mode 100644 internal/transport/tcp.go rename {server => internal/transport}/tls.go (99%) create mode 100644 internal/transport/transport.go rename {version => internal/version}/version.go (100%) delete mode 100644 server/http.go delete mode 100644 server/httpwritter.go delete mode 100644 server/middleware.go diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index ffcc7d9..0874afe 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -8,10 +8,9 @@ import ( "log" "time" "tunnel_pls/internal/config" + "tunnel_pls/internal/registry" "tunnel_pls/types" - "tunnel_pls/session" - proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -32,13 +31,13 @@ type Client interface { type client struct { conn *grpc.ClientConn address string - sessionRegistry session.Registry + sessionRegistry registry.Registry eventService proto.EventServiceClient authorizeConnectionService proto.UserServiceClient closing bool } -func New(address string, sessionRegistry session.Registry) (Client, error) { +func New(address string, sessionRegistry registry.Registry) (Client, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) diff --git a/internal/httpheader/header.go b/internal/http/header/header.go similarity index 96% rename from internal/httpheader/header.go rename to internal/http/header/header.go index ccd1bed..a5e52b3 100644 --- a/internal/httpheader/header.go +++ b/internal/http/header/header.go @@ -1,4 +1,4 @@ -package httpheader +package header type ResponseHeader interface { Value(key string) string diff --git a/internal/httpheader/parser.go b/internal/http/header/parser.go similarity index 99% rename from internal/httpheader/parser.go rename to internal/http/header/parser.go index 3325ae5..861c49e 100644 --- a/internal/httpheader/parser.go +++ b/internal/http/header/parser.go @@ -1,4 +1,4 @@ -package httpheader +package header import ( "bufio" diff --git a/internal/httpheader/request.go b/internal/http/header/request.go similarity index 90% rename from internal/httpheader/request.go rename to internal/http/header/request.go index ae63340..b05f699 100644 --- a/internal/httpheader/request.go +++ b/internal/http/header/request.go @@ -1,11 +1,11 @@ -package httpheader +package header import ( "bufio" "fmt" ) -func NewRequestHeader(r interface{}) (RequestHeader, error) { +func NewRequest(r interface{}) (RequestHeader, error) { switch v := r.(type) { case []byte: return parseHeadersFromBytes(v) diff --git a/internal/httpheader/response.go b/internal/http/header/response.go similarity index 89% rename from internal/httpheader/response.go rename to internal/http/header/response.go index 63ad352..b6305d4 100644 --- a/internal/httpheader/response.go +++ b/internal/http/header/response.go @@ -1,11 +1,11 @@ -package httpheader +package header import ( "bytes" "fmt" ) -func NewResponseHeader(headerData []byte) (ResponseHeader, error) { +func NewResponse(headerData []byte) (ResponseHeader, error) { header := &responseHeader{ startLine: nil, headers: make(map[string]string, 16), diff --git a/internal/http/stream/parser.go b/internal/http/stream/parser.go new file mode 100644 index 0000000..b1d8277 --- /dev/null +++ b/internal/http/stream/parser.go @@ -0,0 +1,29 @@ +package stream + +import "bytes" + +func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) { + headerByte := data[:delimiterIdx+len(DELIMITER)] + body := data[delimiterIdx+len(DELIMITER):] + return headerByte, body +} + +func isHTTPHeader(buf []byte) bool { + lines := bytes.Split(buf, []byte("\r\n")) + + startLine := string(lines[0]) + if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { + return false + } + + for _, line := range lines[1:] { + if len(line) == 0 { + break + } + colonIdx := bytes.IndexByte(line, ':') + if colonIdx <= 0 { + return false + } + } + return true +} diff --git a/internal/http/stream/reader.go b/internal/http/stream/reader.go new file mode 100644 index 0000000..f7c99d4 --- /dev/null +++ b/internal/http/stream/reader.go @@ -0,0 +1,50 @@ +package stream + +import ( + "bytes" + "tunnel_pls/internal/http/header" +) + +func (hs *http) Read(p []byte) (int, error) { + tmp := make([]byte, len(p)) + read, err := hs.reader.Read(tmp) + if read == 0 && err != nil { + return 0, err + } + + tmp = tmp[:read] + + headerEndIdx := bytes.Index(tmp, DELIMITER) + if headerEndIdx == -1 { + return handleNoDelimiter(p, tmp, err) + } + + headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx) + + if !isHTTPHeader(headerByte) { + copy(p, tmp) + return read, nil + } + + return hs.processHTTPRequest(p, headerByte, bodyByte) +} + +func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) { + reqhf, err := header.NewRequest(headerByte) + if err != nil { + return 0, err + } + + if err = hs.ApplyRequestMiddlewares(reqhf); err != nil { + return 0, err + } + + hs.reqHeader = reqhf + combined := append(reqhf.Finalize(), bodyByte...) + return copy(p, combined), nil +} + +func handleNoDelimiter(p, tmp []byte, err error) (int, error) { + copy(p, tmp) + return len(tmp), err +} diff --git a/internal/http/stream/stream.go b/internal/http/stream/stream.go new file mode 100644 index 0000000..97d2752 --- /dev/null +++ b/internal/http/stream/stream.go @@ -0,0 +1,103 @@ +package stream + +import ( + "io" + "log" + "net" + "regexp" + "tunnel_pls/internal/http/header" + "tunnel_pls/internal/middleware" +) + +var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} +var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) +var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) + +type HTTP interface { + io.ReadWriteCloser + CloseWrite() error + RemoteAddr() net.Addr + UseResponseMiddleware(mw middleware.ResponseMiddleware) + UseRequestMiddleware(mw middleware.RequestMiddleware) + SetRequestHeader(header header.RequestHeader) + RequestMiddlewares() []middleware.RequestMiddleware + ResponseMiddlewares() []middleware.ResponseMiddleware + ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error + ApplyRequestMiddlewares(reqhf header.RequestHeader) error +} + +type http struct { + remoteAddr net.Addr + writer io.Writer + reader io.Reader + headerBuf []byte + buf []byte + respHeader header.ResponseHeader + reqHeader header.RequestHeader + respMW []middleware.ResponseMiddleware + reqMW []middleware.RequestMiddleware +} + +func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP { + return &http{ + remoteAddr: remoteAddr, + writer: writer, + reader: reader, + buf: make([]byte, 0, 4096), + } +} + +func (hs *http) RemoteAddr() net.Addr { + return hs.remoteAddr +} + +func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) { + hs.respMW = append(hs.respMW, mw) +} + +func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) { + hs.reqMW = append(hs.reqMW, mw) +} + +func (hs *http) SetRequestHeader(header header.RequestHeader) { + hs.reqHeader = header +} + +func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware { + return hs.reqMW +} + +func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware { + return hs.respMW +} + +func (hs *http) Close() error { + return hs.writer.(io.Closer).Close() +} + +func (hs *http) CloseWrite() error { + if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok { + return closer.CloseWrite() + } + return hs.Close() +} + +func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error { + for _, m := range hs.RequestMiddlewares() { + if err := m.HandleRequest(reqhf); err != nil { + log.Printf("Error when applying request middleware: %v", err) + return err + } + } + return nil +} + +func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error { + for _, m := range hs.ResponseMiddlewares() { + if err := m.HandleResponse(resphf, bodyByte); err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return err + } + } + return nil +} diff --git a/internal/http/stream/writer.go b/internal/http/stream/writer.go new file mode 100644 index 0000000..05e438b --- /dev/null +++ b/internal/http/stream/writer.go @@ -0,0 +1,88 @@ +package stream + +import ( + "bytes" + "tunnel_pls/internal/http/header" +) + +func (hs *http) Write(p []byte) (int, error) { + if hs.shouldBypassBuffering(p) { + hs.respHeader = nil + } + + if hs.respHeader != nil { + return hs.writer.Write(p) + } + + hs.buf = append(hs.buf, p...) + + headerEndIdx := bytes.Index(hs.buf, DELIMITER) + if headerEndIdx == -1 { + return len(p), nil + } + + return hs.processBufferedResponse(p, headerEndIdx) +} + +func (hs *http) shouldBypassBuffering(p []byte) bool { + return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" +} + +func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) { + headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx) + + if !isHTTPHeader(headerByte) { + return hs.writeRawBuffer() + } + + if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil { + return 0, err + } + + hs.buf = nil + return len(p), nil +} + +func (hs *http) writeRawBuffer() (int, error) { + _, err := hs.writer.Write(hs.buf) + length := len(hs.buf) + hs.buf = nil + if err != nil { + return 0, err + } + return length, nil +} + +func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error { + resphf, err := header.NewResponse(headerByte) + if err != nil { + return err + } + + if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil { + return err + } + + hs.respHeader = resphf + finalHeader := resphf.Finalize() + + if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil { + return err + } + + return nil +} + +func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error { + if _, err := hs.writer.Write(header); err != nil { + return err + } + + if len(bodyByte) > 0 { + if _, err := hs.writer.Write(bodyByte); err != nil { + return err + } + } + + return nil +} diff --git a/internal/middleware/forwardedfor.go b/internal/middleware/forwardedfor.go new file mode 100644 index 0000000..6a744a6 --- /dev/null +++ b/internal/middleware/forwardedfor.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "net" + "tunnel_pls/internal/http/header" +) + +type ForwardedFor struct { + addr net.Addr +} + +func NewForwardedFor(addr net.Addr) *ForwardedFor { + return &ForwardedFor{addr: addr} +} + +func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error { + host, _, err := net.SplitHostPort(ff.addr.String()) + if err != nil { + return err + } + header.Set("X-Forwarded-For", host) + return nil +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go new file mode 100644 index 0000000..4e3b163 --- /dev/null +++ b/internal/middleware/middleware.go @@ -0,0 +1,13 @@ +package middleware + +import ( + "tunnel_pls/internal/http/header" +) + +type RequestMiddleware interface { + HandleRequest(header header.RequestHeader) error +} + +type ResponseMiddleware interface { + HandleResponse(header header.ResponseHeader, body []byte) error +} diff --git a/internal/middleware/tunnelfingerprint.go b/internal/middleware/tunnelfingerprint.go new file mode 100644 index 0000000..68171c2 --- /dev/null +++ b/internal/middleware/tunnelfingerprint.go @@ -0,0 +1,16 @@ +package middleware + +import ( + "tunnel_pls/internal/http/header" +) + +type TunnelFingerprint struct{} + +func NewTunnelFingerprint() *TunnelFingerprint { + return &TunnelFingerprint{} +} + +func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error { + header.Set("Server", "Tunnel Please") + return nil +} diff --git a/internal/port/port.go b/internal/port/port.go index 01ecf96..6c60fbb 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -6,37 +6,37 @@ import ( "sync" ) -type Registry interface { - AddPortRange(startPort, endPort uint16) error - GetUnassignedPort() (uint16, bool) - SetPortStatus(port uint16, assigned bool) error - ClaimPort(port uint16) (claimed bool) +type Port interface { + AddRange(startPort, endPort uint16) error + Unassigned() (uint16, bool) + SetStatus(port uint16, assigned bool) error + Claim(port uint16) (claimed bool) } -type registry struct { +type port struct { mu sync.RWMutex ports map[uint16]bool sortedPorts []uint16 } -func New() Registry { - return ®istry{ +func New() Port { + return &port{ ports: make(map[uint16]bool), sortedPorts: []uint16{}, } } -func (pm *registry) AddPortRange(startPort, endPort uint16) error { +func (pm *port) AddRange(startPort, endPort uint16) error { pm.mu.Lock() defer pm.mu.Unlock() if startPort > endPort { return fmt.Errorf("start port cannot be greater than end port") } - for port := startPort; port <= endPort; port++ { - if _, exists := pm.ports[port]; !exists { - pm.ports[port] = false - pm.sortedPorts = append(pm.sortedPorts, port) + for index := startPort; index <= endPort; index++ { + if _, exists := pm.ports[index]; !exists { + pm.ports[index] = false + pm.sortedPorts = append(pm.sortedPorts, index) } } sort.Slice(pm.sortedPorts, func(i, j int) bool { @@ -45,19 +45,19 @@ func (pm *registry) AddPortRange(startPort, endPort uint16) error { return nil } -func (pm *registry) GetUnassignedPort() (uint16, bool) { +func (pm *port) Unassigned() (uint16, bool) { pm.mu.Lock() defer pm.mu.Unlock() - for _, port := range pm.sortedPorts { - if !pm.ports[port] { - return port, true + for _, index := range pm.sortedPorts { + if !pm.ports[index] { + return index, true } } return 0, false } -func (pm *registry) SetPortStatus(port uint16, assigned bool) error { +func (pm *port) SetStatus(port uint16, assigned bool) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -65,7 +65,7 @@ func (pm *registry) SetPortStatus(port uint16, assigned bool) error { return nil } -func (pm *registry) ClaimPort(port uint16) (claimed bool) { +func (pm *port) Claim(port uint16) (claimed bool) { pm.mu.Lock() defer pm.mu.Unlock() diff --git a/session/registry.go b/internal/registry/registry.go similarity index 90% rename from session/registry.go rename to internal/registry/registry.go index 6698cf1..22e590a 100644 --- a/session/registry.go +++ b/internal/registry/registry.go @@ -1,13 +1,25 @@ -package session +package registry import ( "fmt" "sync" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" "tunnel_pls/types" ) type Key = types.SessionKey +type Session interface { + Lifecycle() lifecycle.Lifecycle + Interaction() interaction.Interaction + Forwarder() forwarder.Forwarder + Slug() slug.Slug + Detail() *types.Detail +} + type Registry interface { Get(key Key) (session Session, err error) GetWithUser(user string, key Key) (session Session, err error) @@ -35,12 +47,12 @@ func (r *registry) Get(key Key) (session Session, err error) { userID, ok := r.slugIndex[key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, fmt.Errorf("Session not found") } client, ok := r.byUser[userID][key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, fmt.Errorf("Session not found") } return client, nil } @@ -51,7 +63,7 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error client, ok := r.byUser[user][key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, fmt.Errorf("Session not found") } return client, nil } @@ -81,7 +93,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { } client, ok := r.byUser[user][oldKey] if !ok { - return fmt.Errorf("session not found") + return fmt.Errorf("Session not found") } delete(r.byUser[user], oldKey) @@ -97,7 +109,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { return nil } -func (r *registry) Register(key Key, session Session) (success bool) { +func (r *registry) Register(key Key, userSession Session) (success bool) { r.mu.Lock() defer r.mu.Unlock() @@ -105,12 +117,12 @@ func (r *registry) Register(key Key, session Session) (success bool) { return false } - userID := session.Lifecycle().User() + userID := userSession.Lifecycle().User() if r.byUser[userID] == nil { r.byUser[userID] = make(map[Key]Session) } - r.byUser[userID][key] = session + r.byUser[userID][key] = userSession r.slugIndex[key] = userID return true } diff --git a/internal/transport/http.go b/internal/transport/http.go new file mode 100644 index 0000000..bf698ab --- /dev/null +++ b/internal/transport/http.go @@ -0,0 +1,40 @@ +package transport + +import ( + "errors" + "log" + "net" + "tunnel_pls/internal/registry" +) + +type httpServer struct { + handler *httpHandler + port string +} + +func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { + return &httpServer{ + handler: newHTTPHandler(sessionRegistry, redirectTLS), + port: port, + } +} + +func (ht *httpServer) Listen() (net.Listener, error) { + return net.Listen("tcp", ":"+ht.port) +} + +func (ht *httpServer) Serve(listener net.Listener) error { + log.Printf("HTTP server is starting on port %s", ht.port) + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return err + } + log.Printf("Error accepting connection: %v", err) + continue + } + + go ht.handler.handler(conn, false) + } +} diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go new file mode 100644 index 0000000..0b22e48 --- /dev/null +++ b/internal/transport/httphandler.go @@ -0,0 +1,227 @@ +package transport + +import ( + "bufio" + "errors" + "fmt" + "log" + "net" + "net/http" + "strings" + "time" + "tunnel_pls/internal/config" + "tunnel_pls/internal/http/header" + "tunnel_pls/internal/http/stream" + "tunnel_pls/internal/middleware" + "tunnel_pls/internal/registry" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" +) + +type httpHandler struct { + sessionRegistry registry.Registry + redirectTLS bool +} + +func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { + return &httpHandler{ + sessionRegistry: sessionRegistry, + redirectTLS: redirectTLS, + } +} + +func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error { + _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + + fmt.Sprintf("Location: %s", location) + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "\r\n")) + if err != nil { + return err + } + return nil +} + +func (hh *httpHandler) badRequest(conn net.Conn) error { + if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { + return err + } + return nil +} + +func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { + defer hh.closeConnection(conn) + + dstReader := bufio.NewReader(conn) + reqhf, err := header.NewRequest(dstReader) + if err != nil { + log.Printf("Error creating request header: %v", err) + return + } + + slug, err := hh.extractSlug(reqhf) + if err != nil { + _ = hh.badRequest(conn) + return + } + + if hh.shouldRedirectToTLS(isTLS) { + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))) + return + } + + if hh.handlePingRequest(slug, conn) { + return + } + + sshSession, err := hh.getSession(slug) + if err != nil { + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) + return + } + + hw := stream.New(conn, dstReader, conn.RemoteAddr()) + defer func(hw stream.HTTP) { + err = hw.Close() + if err != nil { + log.Printf("Error closing HTTP stream: %v", err) + } + }(hw) + hh.forwardRequest(hw, reqhf, sshSession) +} + +func (hh *httpHandler) closeConnection(conn net.Conn) { + err := conn.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + log.Printf("Error closing connection: %v", err) + } +} + +func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) { + host := strings.Split(reqhf.Value("Host"), ".") + if len(host) < 1 { + return "", errors.New("invalid host") + } + return host[0], nil +} + +func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool { + return !isTLS && hh.redirectTLS +} + +func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { + if slug != "ping" { + return false + } + + _, err := conn.Write([]byte( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + + "Access-Control-Allow-Headers: *\r\n" + + "\r\n", + )) + if err != nil { + log.Println("Failed to write 200 OK:", err) + } + return true +} + +func (hh *httpHandler) getSession(slug string) (registry.Session, error) { + sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ + Id: slug, + Type: types.HTTP, + }) + if err != nil { + return nil, err + } + return sshSession, nil +} + +func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { + channel, err := hh.openForwardedChannel(hw, sshSession) + defer func() { + err = channel.Close() + if err != nil { + log.Printf("Error closing forwarded channel: %v", err) + } + }() + if err != nil { + log.Printf("Failed to establish channel: %v", err) + sshSession.Forwarder().WriteBadGatewayResponse(hw) + return + } + hh.setupMiddlewares(hw) + + if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil { + log.Printf("Failed to forward initial request: %v", err) + return + } + sshSession.Forwarder().HandleConnection(hw, channel) +} + +func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) { + payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) + + type channelResult struct { + channel ssh.Channel + reqs <-chan *ssh.Request + err error + } + + resultChan := make(chan channelResult, 1) + + go func() { + channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) + select { + case resultChan <- channelResult{channel, reqs, err}: + default: + hh.cleanupUnusedChannel(channel, reqs) + } + }() + + select { + case result := <-resultChan: + if result.err != nil { + return nil, result.err + } + go ssh.DiscardRequests(result.reqs) + return result.channel, nil + case <-time.After(5 * time.Second): + return nil, errors.New("timeout opening forwarded-tcpip channel") + } +} + +func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) { + if channel != nil { + if err := channel.Close(); err != nil { + log.Printf("Failed to close unused channel: %v", err) + } + go ssh.DiscardRequests(reqs) + } +} + +func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) { + fingerprintMiddleware := middleware.NewTunnelFingerprint() + forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr()) + + hw.UseResponseMiddleware(fingerprintMiddleware) + hw.UseRequestMiddleware(forwardedForMiddleware) +} + +func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error { + hw.SetRequestHeader(initialRequest) + + if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { + return fmt.Errorf("error applying request middlewares: %w", err) + } + + if _, err := channel.Write(initialRequest.Finalize()); err != nil { + return fmt.Errorf("error writing to channel: %w", err) + } + + return nil +} diff --git a/internal/transport/https.go b/internal/transport/https.go new file mode 100644 index 0000000..104aa15 --- /dev/null +++ b/internal/transport/https.go @@ -0,0 +1,48 @@ +package transport + +import ( + "crypto/tls" + "errors" + "log" + "net" + "tunnel_pls/internal/registry" +) + +type https struct { + httpHandler *httpHandler + domain string + port string +} + +func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { + return &https{ + httpHandler: newHTTPHandler(sessionRegistry, redirectTLS), + domain: domain, + port: port, + } +} + +func (ht *https) Listen() (net.Listener, error) { + tlsConfig, err := NewTLSConfig(ht.domain) + if err != nil { + return nil, err + } + + return tls.Listen("tcp", ":"+ht.port, tlsConfig) + +} +func (ht *https) Serve(listener net.Listener) error { + log.Printf("HTTPS server is starting on port %s", ht.port) + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return err + } + log.Printf("Error accepting connection: %v", err) + continue + } + + go ht.httpHandler.handler(conn, true) + } +} diff --git a/internal/transport/tcp.go b/internal/transport/tcp.go new file mode 100644 index 0000000..99670d2 --- /dev/null +++ b/internal/transport/tcp.go @@ -0,0 +1,66 @@ +package transport + +import ( + "errors" + "fmt" + "io" + "log" + "net" + + "golang.org/x/crypto/ssh" +) + +type tcp struct { + port uint16 + forwarder forwarder +} + +type forwarder interface { + CreateForwardedTCPIPPayload(origin net.Addr) []byte + OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) + HandleConnection(dst io.ReadWriter, src ssh.Channel) +} + +func NewTCPServer(port uint16, forwarder forwarder) Transport { + return &tcp{ + port: port, + forwarder: forwarder, + } +} + +func (tt *tcp) Listen() (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port)) +} + +func (tt *tcp) Serve(listener net.Listener) error { + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + log.Printf("Error accepting connection: %v", err) + continue + } + go tt.handleTcp(conn) + } +} + +func (tt *tcp) handleTcp(conn net.Conn) { + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Failed to close connection: %v", err) + } + }() + payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr()) + channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + + return + } + + go ssh.DiscardRequests(reqs) + tt.forwarder.HandleConnection(conn, channel) +} diff --git a/server/tls.go b/internal/transport/tls.go similarity index 99% rename from server/tls.go rename to internal/transport/tls.go index fc67733..0893b85 100644 --- a/server/tls.go +++ b/internal/transport/tls.go @@ -1,4 +1,4 @@ -package server +package transport import ( "context" diff --git a/internal/transport/transport.go b/internal/transport/transport.go new file mode 100644 index 0000000..ca27061 --- /dev/null +++ b/internal/transport/transport.go @@ -0,0 +1,10 @@ +package transport + +import ( + "net" +) + +type Transport interface { + Listen() (net.Listener, error) + Serve(listener net.Listener) error +} diff --git a/version/version.go b/internal/version/version.go similarity index 100% rename from version/version.go rename to internal/version/version.go diff --git a/main.go b/main.go index 2303718..6510932 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "net" "net/http" _ "net/http/pprof" "os" @@ -16,9 +17,10 @@ import ( "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/key" "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" + "tunnel_pls/internal/transport" + "tunnel_pls/internal/version" "tunnel_pls/server" - "tunnel_pls/session" - "tunnel_pls/version" "golang.org/x/crypto/ssh" ) @@ -76,7 +78,7 @@ func main() { } sshConfig.AddHostKey(private) - sessionRegistry := session.NewRegistry() + sessionRegistry := registry.NewRegistry() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -131,7 +133,7 @@ func main() { log.Fatalf("Failed to parse end port: %s", err) } - if err = portManager.AddPortRange(uint16(start), uint16(end)); err != nil { + if err = portManager.AddRange(uint16(start), uint16(end)); err != nil { log.Fatalf("Failed to add port range: %s", err) } log.Printf("PortRegistry range configured: %d-%d", start, end) @@ -139,9 +141,51 @@ func main() { log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange) } } + + tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true" + redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" + + go func() { + httpPort := config.Getenv("HTTP_PORT", "8080") + + var httpListener net.Listener + httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS) + httpListener, err = httpserver.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to start http server: %w", err) + return + } + err = httpserver.Serve(httpListener) + if err != nil { + errChan <- fmt.Errorf("error when serving http server: %w", err) + return + } + }() + + if tlsEnabled { + go func() { + httpsPort := config.Getenv("HTTPS_PORT", "8443") + domain := config.Getenv("DOMAIN", "localhost") + + var httpListener net.Listener + httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS) + httpListener, err = httpserver.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to start http server: %w", err) + return + } + err = httpserver.Serve(httpListener) + if err != nil { + errChan <- fmt.Errorf("error when serving http server: %w", err) + return + } + }() + } + var app server.Server go func() { - app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager) + sshPort := config.Getenv("PORT", "2200") + app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort) if err != nil { errChan <- fmt.Errorf("failed to start server: %s", err) return diff --git a/server/http.go b/server/http.go deleted file mode 100644 index a36b6b1..0000000 --- a/server/http.go +++ /dev/null @@ -1,276 +0,0 @@ -package server - -import ( - "bufio" - "crypto/tls" - "errors" - "fmt" - "log" - "net" - "net/http" - "strings" - "time" - "tunnel_pls/internal/config" - "tunnel_pls/internal/httpheader" - "tunnel_pls/session" - "tunnel_pls/types" - - "golang.org/x/crypto/ssh" -) - -type HTTPServer interface { - ListenAndServe() error - ListenAndServeTLS() error -} -type httpServer struct { - sessionRegistry session.Registry - redirectTLS bool -} - -func NewHTTPServer(sessionRegistry session.Registry, redirectTLS bool) HTTPServer { - return &httpServer{ - sessionRegistry: sessionRegistry, - redirectTLS: redirectTLS, - } -} - -func (hs *httpServer) ListenAndServe() error { - httpPort := config.Getenv("HTTP_PORT", "8080") - listener, err := net.Listen("tcp", ":"+httpPort) - if err != nil { - return errors.New("Error listening: " + err.Error()) - } - go func() { - for { - var conn net.Conn - conn, err = listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go hs.handler(conn, false) - } - }() - return nil -} - -func (hs *httpServer) ListenAndServeTLS() error { - domain := config.Getenv("DOMAIN", "localhost") - httpsPort := config.Getenv("HTTPS_PORT", "8443") - - tlsConfig, err := NewTLSConfig(domain) - if err != nil { - return fmt.Errorf("failed to initialize TLS config: %w", err) - } - - ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig) - if err != nil { - return err - } - - go func() { - for { - var conn net.Conn - conn, err = ln.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - log.Println("https server closed") - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go hs.handler(conn, true) - } - }() - return nil -} - -func (hs *httpServer) redirect(conn net.Conn, status int, location string) error { - _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + - fmt.Sprintf("Location: %s", location) + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n")) - if err != nil { - return err - } - return nil -} - -func (hs *httpServer) badRequest(conn net.Conn) error { - if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { - return err - } - return nil -} - -func (hs *httpServer) handler(conn net.Conn, isTLS bool) { - defer hs.closeConnection(conn) - - dstReader := bufio.NewReader(conn) - reqhf, err := httpheader.NewRequestHeader(dstReader) - if err != nil { - log.Printf("Error creating request header: %v", err) - return - } - - slug, err := hs.extractSlug(reqhf) - if err != nil { - _ = hs.badRequest(conn) - return - } - - if hs.shouldRedirectToTLS(isTLS) { - _ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))) - return - } - - if hs.handlePingRequest(slug, conn) { - return - } - - sshSession, err := hs.getSession(slug) - if err != nil { - _ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) - return - } - - hw := NewHTTPWriter(conn, dstReader, conn.RemoteAddr()) - hs.forwardRequest(hw, reqhf, sshSession) -} - -func (hs *httpServer) closeConnection(conn net.Conn) { - err := conn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - log.Printf("Error closing connection: %v", err) - } -} - -func (hs *httpServer) extractSlug(reqhf httpheader.RequestHeader) (string, error) { - host := strings.Split(reqhf.Value("Host"), ".") - if len(host) < 1 { - return "", errors.New("invalid host") - } - return host[0], nil -} - -func (hs *httpServer) shouldRedirectToTLS(isTLS bool) bool { - return !isTLS && hs.redirectTLS -} - -func (hs *httpServer) handlePingRequest(slug string, conn net.Conn) bool { - if slug != "ping" { - return false - } - - _, err := conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "Access-Control-Allow-Origin: *\r\n" + - "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + - "Access-Control-Allow-Headers: *\r\n" + - "\r\n", - )) - if err != nil { - log.Println("Failed to write 200 OK:", err) - } - return true -} - -func (hs *httpServer) getSession(slug string) (session.Session, error) { - sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ - Id: slug, - Type: types.HTTP, - }) - if err != nil { - return nil, err - } - return sshSession, nil -} - -func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, sshSession session.Session) { - channel, err := hs.openForwardedChannel(hw, sshSession) - if err != nil { - log.Printf("Failed to establish channel: %v", err) - sshSession.Forwarder().WriteBadGatewayResponse(hw) - return - } - - hs.setupMiddlewares(hw) - - if err := hs.sendInitialRequest(hw, initialRequest, channel); err != nil { - log.Printf("Failed to forward initial request: %v", err) - return - } - - sshSession.Forwarder().HandleConnection(hw, channel, hw.RemoteAddr()) -} - -func (hs *httpServer) openForwardedChannel(hw HTTPWriter, sshSession session.Session) (ssh.Channel, error) { - payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) - - type channelResult struct { - channel ssh.Channel - reqs <-chan *ssh.Request - err error - } - - resultChan := make(chan channelResult, 1) - - go func() { - channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) - select { - case resultChan <- channelResult{channel, reqs, err}: - default: - hs.cleanupUnusedChannel(channel, reqs) - } - }() - - select { - case result := <-resultChan: - if result.err != nil { - return nil, result.err - } - go ssh.DiscardRequests(result.reqs) - return result.channel, nil - case <-time.After(5 * time.Second): - return nil, errors.New("timeout opening forwarded-tcpip channel") - } -} - -func (hs *httpServer) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) { - if channel != nil { - if err := channel.Close(); err != nil { - log.Printf("Failed to close unused channel: %v", err) - } - go ssh.DiscardRequests(reqs) - } -} - -func (hs *httpServer) setupMiddlewares(hw HTTPWriter) { - fingerprintMiddleware := NewTunnelFingerprint() - forwardedForMiddleware := NewForwardedFor(hw.RemoteAddr()) - - hw.UseResponseMiddleware(fingerprintMiddleware) - hw.UseRequestMiddleware(forwardedForMiddleware) -} - -func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, channel ssh.Channel) error { - hw.SetRequestHeader(initialRequest) - - if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { - return fmt.Errorf("error applying request middlewares: %w", err) - } - - if _, err := channel.Write(initialRequest.Finalize()); err != nil { - return fmt.Errorf("error writing to channel: %w", err) - } - - return nil -} diff --git a/server/httpwritter.go b/server/httpwritter.go deleted file mode 100644 index 9d52f24..0000000 --- a/server/httpwritter.go +++ /dev/null @@ -1,254 +0,0 @@ -package server - -import ( - "bytes" - "io" - "log" - "net" - "regexp" - "tunnel_pls/internal/httpheader" -) - -type HTTPWriter interface { - io.ReadWriteCloser - CloseWrite() error - RemoteAddr() net.Addr - UseResponseMiddleware(mw ResponseMiddleware) - UseRequestMiddleware(mw RequestMiddleware) - SetRequestHeader(header httpheader.RequestHeader) - RequestMiddlewares() []RequestMiddleware - ResponseMiddlewares() []ResponseMiddleware - ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error - ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error -} - -type httpWriter struct { - remoteAddr net.Addr - writer io.Writer - reader io.Reader - headerBuf []byte - buf []byte - respHeader httpheader.ResponseHeader - reqHeader httpheader.RequestHeader - respMW []ResponseMiddleware - reqMW []RequestMiddleware -} - -var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} -var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) -var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) - -func (hw *httpWriter) RemoteAddr() net.Addr { - return hw.remoteAddr -} - -func (hw *httpWriter) UseResponseMiddleware(mw ResponseMiddleware) { - hw.respMW = append(hw.respMW, mw) -} - -func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) { - hw.reqMW = append(hw.reqMW, mw) -} - -func (hw *httpWriter) SetRequestHeader(header httpheader.RequestHeader) { - hw.reqHeader = header -} - -func (hw *httpWriter) RequestMiddlewares() []RequestMiddleware { - return hw.reqMW -} - -func (hw *httpWriter) ResponseMiddlewares() []ResponseMiddleware { - return hw.respMW -} -func (hw *httpWriter) Close() error { - return hw.writer.(io.Closer).Close() -} - -func (hw *httpWriter) CloseWrite() error { - if closer, ok := hw.writer.(interface{ CloseWrite() error }); ok { - return closer.CloseWrite() - } - return hw.Close() -} - -func (hw *httpWriter) Read(p []byte) (int, error) { - tmp := make([]byte, len(p)) - read, err := hw.reader.Read(tmp) - if read == 0 && err != nil { - return 0, err - } - - tmp = tmp[:read] - - headerEndIdx := bytes.Index(tmp, DELIMITER) - if headerEndIdx == -1 { - return hw.handleNoDelimiter(p, tmp, err) - } - - header, body := hw.splitHeaderAndBody(tmp, headerEndIdx) - - if !isHTTPHeader(header) { - copy(p, tmp) - return read, nil - } - - return hw.processHTTPRequest(p, header, body) -} - -func (hw *httpWriter) handleNoDelimiter(p, tmp []byte, err error) (int, error) { - copy(p, tmp) - return len(tmp), err -} - -func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) { - header := data[:delimiterIdx+len(DELIMITER)] - body := data[delimiterIdx+len(DELIMITER):] - return header, body -} - -func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { - reqhf, err := httpheader.NewRequestHeader(header) - if err != nil { - return 0, err - } - - if err = hw.ApplyRequestMiddlewares(reqhf); err != nil { - return 0, err - } - - hw.reqHeader = reqhf - combined := append(reqhf.Finalize(), body...) - return copy(p, combined), nil -} - -func (hw *httpWriter) ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error { - for _, m := range hw.RequestMiddlewares() { - if err := m.HandleRequest(reqhf); err != nil { - log.Printf("Error when applying request middleware: %v", err) - return err - } - } - return nil -} - -func (hw *httpWriter) Write(p []byte) (int, error) { - if hw.shouldBypassBuffering(p) { - hw.respHeader = nil - } - - if hw.respHeader != nil { - return hw.writer.Write(p) - } - - hw.buf = append(hw.buf, p...) - - headerEndIdx := bytes.Index(hw.buf, DELIMITER) - if headerEndIdx == -1 { - return len(p), nil - } - - return hw.processBufferedResponse(p, headerEndIdx) -} - -func (hw *httpWriter) shouldBypassBuffering(p []byte) bool { - return hw.respHeader != nil && len(hw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" -} - -func (hw *httpWriter) processBufferedResponse(p []byte, delimiterIdx int) (int, error) { - header, body := hw.splitHeaderAndBody(hw.buf, delimiterIdx) - - if !isHTTPHeader(header) { - return hw.writeRawBuffer() - } - - if err := hw.processHTTPResponse(header, body); err != nil { - return 0, err - } - - hw.buf = nil - return len(p), nil -} - -func (hw *httpWriter) writeRawBuffer() (int, error) { - _, err := hw.writer.Write(hw.buf) - length := len(hw.buf) - hw.buf = nil - if err != nil { - return 0, err - } - return length, nil -} - -func (hw *httpWriter) processHTTPResponse(header, body []byte) error { - resphf, err := httpheader.NewResponseHeader(header) - if err != nil { - return err - } - - if err = hw.ApplyResponseMiddlewares(resphf, body); err != nil { - return err - } - - hw.respHeader = resphf - finalHeader := resphf.Finalize() - - if err = hw.writeHeaderAndBody(finalHeader, body); err != nil { - return err - } - - return nil -} - -func (hw *httpWriter) ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error { - for _, m := range hw.ResponseMiddlewares() { - if err := m.HandleResponse(resphf, body); err != nil { - log.Printf("Cannot apply middleware: %s\n", err) - return err - } - } - return nil -} - -func (hw *httpWriter) writeHeaderAndBody(header, body []byte) error { - if _, err := hw.writer.Write(header); err != nil { - return err - } - - if len(body) > 0 { - if _, err := hw.writer.Write(body); err != nil { - return err - } - } - - return nil -} - -func NewHTTPWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter { - return &httpWriter{ - remoteAddr: remoteAddr, - writer: writer, - reader: reader, - buf: make([]byte, 0, 4096), - } -} - -func isHTTPHeader(buf []byte) bool { - lines := bytes.Split(buf, []byte("\r\n")) - - startLine := string(lines[0]) - if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { - return false - } - - for _, line := range lines[1:] { - if len(line) == 0 { - break - } - colonIdx := bytes.IndexByte(line, ':') - if colonIdx <= 0 { - return false - } - } - return true -} diff --git a/server/middleware.go b/server/middleware.go deleted file mode 100644 index 6f50c4c..0000000 --- a/server/middleware.go +++ /dev/null @@ -1,42 +0,0 @@ -package server - -import ( - "net" - "tunnel_pls/internal/httpheader" -) - -type RequestMiddleware interface { - HandleRequest(header httpheader.RequestHeader) error -} - -type ResponseMiddleware interface { - HandleResponse(header httpheader.ResponseHeader, body []byte) error -} - -type TunnelFingerprint struct{} - -func NewTunnelFingerprint() *TunnelFingerprint { - return &TunnelFingerprint{} -} - -func (h *TunnelFingerprint) HandleResponse(header httpheader.ResponseHeader, body []byte) error { - header.Set("Server", "Tunnel Please") - return nil -} - -type ForwardedFor struct { - addr net.Addr -} - -func NewForwardedFor(addr net.Addr) *ForwardedFor { - return &ForwardedFor{addr: addr} -} - -func (ff *ForwardedFor) HandleRequest(header httpheader.RequestHeader) error { - host, _, err := net.SplitHostPort(ff.addr.String()) - if err != nil { - return err - } - header.Set("X-Forwarded-For", host) - return nil -} diff --git a/server/server.go b/server/server.go index 868b9e6..185d051 100644 --- a/server/server.go +++ b/server/server.go @@ -7,9 +7,9 @@ import ( "log" "net" "time" - "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" "tunnel_pls/session" "golang.org/x/crypto/ssh" @@ -20,38 +20,23 @@ type Server interface { Close() error } type server struct { - listener net.Listener + sshPort string + sshListener net.Listener config *ssh.ServerConfig grpcClient client.Client - sessionRegistry session.Registry - portRegistry port.Registry + sessionRegistry registry.Registry + portRegistry port.Port } -func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) { - listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) +func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { + listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort)) if err != nil { - log.Fatalf("failed to listen on port 2200: %v", err) return nil, err } - redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" - - HttpServer := NewHTTPServer(sessionRegistry, redirectTLS) - err = HttpServer.ListenAndServe() - if err != nil { - log.Fatalf("failed to start http server: %v", err) - return nil, err - } - - if config.Getenv("TLS_ENABLED", "false") == "true" { - err = HttpServer.ListenAndServeTLS() - if err != nil { - log.Fatalf("failed to start https server: %v", err) - return nil, err - } - } return &server{ - listener: listener, + sshPort: sshPort, + sshListener: listener, config: sshConfig, grpcClient: grpcClient, sessionRegistry: sessionRegistry, @@ -60,9 +45,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie } func (s *server) Start() { - log.Println("SSH server is starting on port 2200...") + log.Printf("SSH server is starting on port %s", s.sshPort) for { - conn, err := s.listener.Accept() + conn, err := s.sshListener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { log.Println("listener closed, stopping server") @@ -77,7 +62,7 @@ func (s *server) Start() { } func (s *server) Close() error { - return s.listener.Close() + return s.sshListener.Close() } func (s *server) handleConnection(conn net.Conn) { diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index cac6691..ff2abde 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -56,14 +56,14 @@ type Forwarder interface { Listener() net.Listener TunnelType() types.TunnelType ForwardedPort() uint16 - HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) + HandleConnection(dst io.ReadWriter, src ssh.Channel) CreateForwardedTCPIPPayload(origin net.Addr) []byte + OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) WriteBadGatewayResponse(dst io.Writer) - AcceptTCPConnections() Close() error } -func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { +func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { type channelResult struct { channel ssh.Channel reqs <-chan *ssh.Request @@ -95,38 +95,6 @@ func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *s } } -func (f *forwarder) handleIncomingConnection(conn net.Conn) { - payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) - - channel, reqs, err := f.openForwardedChannel(payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - err = conn.Close() - if err != nil { - log.Printf("Failed to close connection: %v", err) - } - return - } - - go ssh.DiscardRequests(reqs) - go f.HandleConnection(conn, channel, conn.RemoteAddr()) -} - -func (f *forwarder) AcceptTCPConnections() { - for { - conn, err := f.Listener().Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go f.handleIncomingConnection(conn) - } -} - func closeWriter(w io.Writer) error { if cw, ok := w.(interface{ CloseWrite() error }); ok { return cw.CloseWrite() @@ -145,12 +113,12 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) } if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) { - errs = append(errs, fmt.Errorf("close writer error (%s): %w", direction, err)) + errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err)) } return errors.Join(errs...) } -func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { +func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { defer func() { _, err := io.Copy(io.Discard, src) if err != nil { @@ -158,8 +126,6 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA } }() - log.Printf("Handling new forwarded connection from %s", remoteAddr) - var wg sync.WaitGroup wg.Add(2) diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 8d134a2..7a2fcaf 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -31,11 +31,11 @@ type lifecycle struct { slug slug.Slug startedAt time.Time sessionRegistry SessionRegistry - portRegistry portUtil.Registry + portRegistry portUtil.Port user string } -func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle { +func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle { return &lifecycle{ status: types.INITIALIZING, conn: conn, @@ -51,7 +51,7 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti type Lifecycle interface { Connection() ssh.Conn - PortRegistry() portUtil.Registry + PortRegistry() portUtil.Port User() string SetChannel(channel ssh.Channel) SetStatus(status types.Status) @@ -60,7 +60,7 @@ type Lifecycle interface { Close() error } -func (l *lifecycle) PortRegistry() portUtil.Registry { +func (l *lifecycle) PortRegistry() portUtil.Port { return l.portRegistry } @@ -113,7 +113,7 @@ func (l *lifecycle) Close() error { l.sessionRegistry.Remove(key) if tunnelType == types.TCP { - if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { + if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { firstErr = err } } diff --git a/session/session.go b/session/session.go index 4118550..d113084 100644 --- a/session/session.go +++ b/session/session.go @@ -12,6 +12,8 @@ import ( "tunnel_pls/internal/config" portUtil "tunnel_pls/internal/port" "tunnel_pls/internal/random" + "tunnel_pls/internal/registry" + "tunnel_pls/internal/transport" "tunnel_pls/session/forwarder" "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" @@ -21,14 +23,6 @@ import ( "golang.org/x/crypto/ssh" ) -type Detail struct { - ForwardingType string `json:"forwarding_type,omitempty"` - Slug string `json:"slug,omitempty"` - UserID string `json:"user_id,omitempty"` - Active bool `json:"active,omitempty"` - StartedAt time.Time `json:"started_at,omitempty"` -} - type Session interface { HandleGlobalRequest(ch <-chan *ssh.Request) error HandleTCPIPForward(req *ssh.Request) error @@ -38,7 +32,7 @@ type Session interface { Interaction() interaction.Interaction Forwarder() forwarder.Forwarder Slug() slug.Slug - Detail() *Detail + Detail() *types.Detail Start() error } @@ -49,12 +43,12 @@ type session struct { interaction interaction.Interaction forwarder forwarder.Forwarder slug slug.Slug - registry Registry + registry registry.Registry } var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session { +func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { slugManager := slug.New() forwarderManager := forwarder.New(slugManager, conn) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) @@ -87,7 +81,7 @@ func (s *session) Slug() slug.Slug { return s.slug } -func (s *session) Detail() *Detail { +func (s *session) Detail() *types.Detail { tunnelTypeMap := map[types.TunnelType]string{ types.HTTP: "HTTP", types.TCP: "TCP", @@ -97,7 +91,7 @@ func (s *session) Detail() *Detail { tunnelType = "UNKNOWN" } - return &Detail{ + return &types.Detail{ ForwardingType: tunnelType, Slug: s.slug.String(), UserID: s.lifecycle.User(), @@ -271,7 +265,7 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, } if port == 0 { - unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort() + unassigned, ok := s.lifecycle.PortRegistry().Unassigned() if !ok { return "", 0, fmt.Errorf("no available port") } @@ -328,7 +322,6 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen if listener != nil { s.forwarder.SetListener(listener) - go s.forwarder.AcceptTCPConnections() } return nil @@ -346,7 +339,6 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) error { case 80, 443: return s.HandleHTTPForward(req, port) default: - return s.HandleTCPForward(req, address, port) } } @@ -369,11 +361,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { } func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error { - if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed { + if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } - listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) + tcpServer := transport.NewTCPServer(portToBind, s.forwarder) + listener, err := tcpServer.Listen() if err != nil { return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } @@ -387,6 +380,14 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin if err != nil { return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } + + go func() { + err = tcpServer.Serve(listener) + if err != nil { + log.Printf("Failed serving tcp server: %s\n", err) + } + }() + return nil } diff --git a/types/types.go b/types/types.go index 148cd2b..b91dffb 100644 --- a/types/types.go +++ b/types/types.go @@ -1,5 +1,7 @@ package types +import "time" + type Status int const ( @@ -27,6 +29,14 @@ type SessionKey struct { Type TunnelType } +type Detail struct { + ForwardingType string `json:"forwarding_type,omitempty"` + Slug string `json:"slug,omitempty"` + UserID string `json:"user_id,omitempty"` + Active bool `json:"active,omitempty"` + StartedAt time.Time `json:"started_at,omitempty"` +} + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + "Content-Length: 11\r\n" + "Content-Type: text/plain\r\n\r\n" + -- 2.49.1 From 2bc20dd99154500e369eea94345e6e1d131c029d Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 21 Jan 2026 19:43:19 +0700 Subject: [PATCH 08/11] refactor(config): centralize env loading and enforce typed access - Centralize environment variable loading in config.MustLoad - Parse and validate all env vars once at initialization - Make config fields private and read-only - Remove public Getenv usage in favor of typed accessors - Improve validation and initialization order - Normalize enum naming to be idiomatic and avoid constant collisions --- internal/config/config.go | 76 +++++++++---- internal/config/loader.go | 170 +++++++++++++++++++++++++++++ internal/grpc/client/client.go | 16 +-- internal/http/header/header.go | 6 +- internal/http/header/request.go | 6 +- internal/registry/registry.go | 10 +- internal/transport/http.go | 4 +- internal/transport/httphandler.go | 32 +++--- internal/transport/https.go | 15 +-- internal/transport/tls.go | 45 ++------ main.go | 103 ++++++----------- server/server.go | 13 ++- session/forwarder/forwarder.go | 63 ++++++----- session/interaction/interaction.go | 21 ++-- session/interaction/model.go | 2 +- session/interaction/slug.go | 12 +- session/lifecycle/lifecycle.go | 14 +-- session/session.go | 36 +++--- types/types.go | 27 +++-- 19 files changed, 414 insertions(+), 257 deletions(-) create mode 100644 internal/config/loader.go diff --git a/internal/config/config.go b/internal/config/config.go index 45f1cc5..62e1aca 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,33 +1,63 @@ package config -import ( - "os" - "strconv" +import "tunnel_pls/types" - "github.com/joho/godotenv" -) +type Config interface { + Domain() string + SSHPort() string -func Load() error { - if _, err := os.Stat(".env"); err == nil { - return godotenv.Load(".env") - } - return nil + HTTPPort() string + HTTPSPort() string + + TLSEnabled() bool + TLSRedirect() bool + + ACMEEmail() string + CFAPIToken() string + ACMEStaging() bool + + AllowedPortsStart() uint16 + AllowedPortsEnd() uint16 + + BufferSize() int + + PprofEnabled() bool + PprofPort() string + + Mode() types.ServerMode + GRPCAddress() string + GRPCPort() string + NodeToken() string } -func Getenv(key, defaultValue string) string { - val := os.Getenv(key) - if val == "" { - val = defaultValue +func MustLoad() (Config, error) { + if err := loadEnvFile(); err != nil { + return nil, err } - return val + cfg, err := parse() + if err != nil { + return nil, err + } + + return cfg, nil } -func GetBufferSize() int { - sizeStr := Getenv("BUFFER_SIZE", "32768") - size, err := strconv.Atoi(sizeStr) - if err != nil || size < 4096 || size > 1048576 { - return 32768 - } - return size -} +func (c *config) Domain() string { return c.domain } +func (c *config) SSHPort() string { return c.sshPort } +func (c *config) HTTPPort() string { return c.httpPort } +func (c *config) HTTPSPort() string { return c.httpsPort } +func (c *config) TLSEnabled() bool { return c.tlsEnabled } +func (c *config) TLSRedirect() bool { return c.tlsRedirect } +func (c *config) ACMEEmail() string { return c.acmeEmail } +func (c *config) CFAPIToken() string { return c.cfAPIToken } +func (c *config) ACMEStaging() bool { return c.acmeStaging } +func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart } +func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd } +func (c *config) BufferSize() int { return c.bufferSize } +func (c *config) PprofEnabled() bool { return c.pprofEnabled } +func (c *config) PprofPort() string { return c.pprofPort } +func (c *config) Mode() types.ServerMode { return c.mode } +func (c *config) GRPCAddress() string { return c.grpcAddress } +func (c *config) GRPCPort() string { return c.grpcPort } +func (c *config) NodeToken() string { return c.nodeToken } diff --git a/internal/config/loader.go b/internal/config/loader.go new file mode 100644 index 0000000..cde9fd0 --- /dev/null +++ b/internal/config/loader.go @@ -0,0 +1,170 @@ +package config + +import ( + "fmt" + "log" + "os" + "strconv" + "strings" + "tunnel_pls/types" + + "github.com/joho/godotenv" +) + +type config struct { + domain string + sshPort string + + httpPort string + httpsPort string + + tlsEnabled bool + tlsRedirect bool + + acmeEmail string + cfAPIToken string + acmeStaging bool + + allowedPortsStart uint16 + allowedPortsEnd uint16 + + bufferSize int + + pprofEnabled bool + pprofPort string + + mode types.ServerMode + grpcAddress string + grpcPort string + nodeToken string +} + +func parse() (*config, error) { + mode, err := parseMode() + if err != nil { + return nil, err + } + + domain := getenv("DOMAIN", "localhost") + sshPort := getenv("PORT", "2200") + + httpPort := getenv("HTTP_PORT", "8080") + httpsPort := getenv("HTTPS_PORT", "8443") + + tlsEnabled := getenvBool("TLS_ENABLED", false) + tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false) + + acmeEmail := getenv("ACME_EMAIL", "admin@"+domain) + acmeStaging := getenvBool("ACME_STAGING", false) + + cfToken := getenv("CF_API_TOKEN", "") + if tlsEnabled && cfToken == "" { + return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled") + } + + start, end, err := parseAllowedPorts() + if err != nil { + return nil, err + } + + bufferSize := parseBufferSize() + + pprofEnabled := getenvBool("PPROF_ENABLED", false) + pprofPort := getenv("PPROF_PORT", "6060") + + grpcHost := getenv("GRPC_ADDRESS", "localhost") + grpcPort := getenv("GRPC_PORT", "8080") + + nodeToken := getenv("NODE_TOKEN", "") + if mode == types.ServerModeNODE && nodeToken == "" { + return nil, fmt.Errorf("NODE_TOKEN is required in node mode") + } + + return &config{ + domain: domain, + sshPort: sshPort, + httpPort: httpPort, + httpsPort: httpsPort, + tlsEnabled: tlsEnabled, + tlsRedirect: tlsRedirect, + acmeEmail: acmeEmail, + cfAPIToken: cfToken, + acmeStaging: acmeStaging, + allowedPortsStart: start, + allowedPortsEnd: end, + bufferSize: bufferSize, + pprofEnabled: pprofEnabled, + pprofPort: pprofPort, + mode: mode, + grpcAddress: grpcHost, + grpcPort: grpcPort, + nodeToken: nodeToken, + }, nil +} + +func loadEnvFile() error { + if _, err := os.Stat(".env"); err == nil { + return godotenv.Load(".env") + } + return nil +} + +func parseMode() (types.ServerMode, error) { + switch strings.ToLower(getenv("MODE", "standalone")) { + case "standalone": + return types.ServerModeSTANDALONE, nil + case "node": + return types.ServerModeNODE, nil + default: + return 0, fmt.Errorf("invalid MODE value") + } +} + +func parseAllowedPorts() (uint16, uint16, error) { + raw := getenv("ALLOWED_PORTS", "") + if raw == "" { + return 0, 0, nil + } + + parts := strings.Split(raw, "-") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format") + } + + start, err := strconv.ParseUint(parts[0], 10, 16) + if err != nil { + return 0, 0, err + } + + end, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return 0, 0, err + } + + return uint16(start), uint16(end), nil +} + +func parseBufferSize() int { + raw := getenv("BUFFER_SIZE", "32768") + size, err := strconv.Atoi(raw) + if err != nil || size < 4096 || size > 1048576 { + log.Println("Invalid BUFFER_SIZE, falling back to 4096") + return 4096 + } + return size +} + +func getenv(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} + +func getenvBool(key string, def bool) bool { + val := os.Getenv(key) + if val == "" { + return def + } + return val == "true" +} diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 0874afe..f2e0a1e 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -29,6 +29,7 @@ type Client interface { CheckServerHealth(ctx context.Context) error } type client struct { + config config.Config conn *grpc.ClientConn address string sessionRegistry registry.Registry @@ -37,7 +38,7 @@ type client struct { closing bool } -func New(address string, sessionRegistry registry.Registry) (Client, error) { +func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -66,6 +67,7 @@ func New(address string, sessionRegistry registry.Registry) (Client, error) { authorizeConnectionService := proto.NewUserServiceClient(conn) return &client{ + config: config, conn: conn, address: address, sessionRegistry: sessionRegistry, @@ -192,7 +194,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, oldSlug := slugEvent.GetOld() newSlug := slugEvent.GetNew() - userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP}) + userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}) if err != nil { return c.sendNode(subscribe, &proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, @@ -202,7 +204,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, }, "slug change failure response") } - if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil { + if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil { return c.sendNode(subscribe, &proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, Payload: &proto.Node_SlugEventResponse{ @@ -227,7 +229,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node for _, ses := range sessions { detail := ses.Detail() details = append(details, &proto.Detail{ - Node: config.Getenv("DOMAIN", "localhost"), + Node: c.config.Domain(), ForwardingType: detail.ForwardingType, Slug: detail.Slug, UserId: detail.UserID, @@ -299,11 +301,11 @@ func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) { switch t { case proto.TunnelType_HTTP: - return types.HTTP, nil + return types.TunnelTypeHTTP, nil case proto.TunnelType_TCP: - return types.TCP, nil + return types.TunnelTypeTCP, nil default: - return types.UNKNOWN, fmt.Errorf("unknown tunnel type received") + return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received") } } diff --git a/internal/http/header/header.go b/internal/http/header/header.go index a5e52b3..605f9ec 100644 --- a/internal/http/header/header.go +++ b/internal/http/header/header.go @@ -17,9 +17,9 @@ type RequestHeader interface { Set(key string, value string) Remove(key string) Finalize() []byte - GetMethod() string - GetPath() string - GetVersion() string + Method() string + Path() string + Version() string } type requestHeader struct { method string diff --git a/internal/http/header/request.go b/internal/http/header/request.go index b05f699..1fbe57a 100644 --- a/internal/http/header/request.go +++ b/internal/http/header/request.go @@ -32,15 +32,15 @@ func (req *requestHeader) Remove(key string) { delete(req.headers, key) } -func (req *requestHeader) GetMethod() string { +func (req *requestHeader) Method() string { return req.method } -func (req *requestHeader) GetPath() string { +func (req *requestHeader) Path() string { return req.path } -func (req *requestHeader) GetVersion() string { +func (req *requestHeader) Version() string { return req.version } diff --git a/internal/registry/registry.go b/internal/registry/registry.go index 22e590a..86898b0 100644 --- a/internal/registry/registry.go +++ b/internal/registry/registry.go @@ -47,12 +47,12 @@ func (r *registry) Get(key Key) (session Session, err error) { userID, ok := r.slugIndex[key] if !ok { - return nil, fmt.Errorf("Session not found") + return nil, fmt.Errorf("session not found") } client, ok := r.byUser[userID][key] if !ok { - return nil, fmt.Errorf("Session not found") + return nil, fmt.Errorf("session not found") } return client, nil } @@ -63,7 +63,7 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error client, ok := r.byUser[user][key] if !ok { - return nil, fmt.Errorf("Session not found") + return nil, fmt.Errorf("session not found") } return client, nil } @@ -73,7 +73,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { return fmt.Errorf("tunnel type cannot change") } - if newKey.Type != types.HTTP { + if newKey.Type != types.TunnelTypeHTTP { return fmt.Errorf("non http tunnel cannot change slug") } @@ -93,7 +93,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { } client, ok := r.byUser[user][oldKey] if !ok { - return fmt.Errorf("Session not found") + return fmt.Errorf("session not found") } delete(r.byUser[user], oldKey) diff --git a/internal/transport/http.go b/internal/transport/http.go index bf698ab..dd091c3 100644 --- a/internal/transport/http.go +++ b/internal/transport/http.go @@ -12,9 +12,9 @@ type httpServer struct { port string } -func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { +func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { return &httpServer{ - handler: newHTTPHandler(sessionRegistry, redirectTLS), + handler: newHTTPHandler(domain, sessionRegistry, redirectTLS), port: port, } } diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go index 0b22e48..b6f128d 100644 --- a/internal/transport/httphandler.go +++ b/internal/transport/httphandler.go @@ -4,12 +4,12 @@ import ( "bufio" "errors" "fmt" + "io" "log" "net" "net/http" "strings" "time" - "tunnel_pls/internal/config" "tunnel_pls/internal/http/header" "tunnel_pls/internal/http/stream" "tunnel_pls/internal/middleware" @@ -20,19 +20,21 @@ import ( ) type httpHandler struct { + domain string sessionRegistry registry.Registry redirectTLS bool } -func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { +func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { return &httpHandler{ + domain: domain, sessionRegistry: sessionRegistry, redirectTLS: redirectTLS, } } func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error { - _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + + _, err := conn.Write([]byte(fmt.Sprintf("TunnelTypeHTTP/1.1 %d Moved Permanently\r\n", status) + fmt.Sprintf("Location: %s", location) + "Content-Length: 0\r\n" + "Connection: close\r\n" + @@ -44,7 +46,7 @@ func (hh *httpHandler) redirect(conn net.Conn, status int, location string) erro } func (hh *httpHandler) badRequest(conn net.Conn) error { - if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { + if _, err := conn.Write([]byte("TunnelTypeHTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { return err } return nil @@ -67,7 +69,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { } if hh.shouldRedirectToTLS(isTLS) { - _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))) + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain)) return } @@ -85,7 +87,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { defer func(hw stream.HTTP) { err = hw.Close() if err != nil { - log.Printf("Error closing HTTP stream: %v", err) + log.Printf("Error closing TunnelTypeHTTP stream: %v", err) } }(hw) hh.forwardRequest(hw, reqhf, sshSession) @@ -116,7 +118,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { } _, err := conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + + "TunnelTypeHTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Connection: close\r\n" + "Access-Control-Allow-Origin: *\r\n" + @@ -133,7 +135,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { func (hh *httpHandler) getSession(slug string) (registry.Session, error) { sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ Id: slug, - Type: types.HTTP, + Type: types.TunnelTypeHTTP, }) if err != nil { return nil, err @@ -143,17 +145,19 @@ func (hh *httpHandler) getSession(slug string) (registry.Session, error) { func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { channel, err := hh.openForwardedChannel(hw, sshSession) - defer func() { - err = channel.Close() - if err != nil { - log.Printf("Error closing forwarded channel: %v", err) - } - }() if err != nil { log.Printf("Failed to establish channel: %v", err) sshSession.Forwarder().WriteBadGatewayResponse(hw) return } + + defer func() { + err = channel.Close() + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing forwarded channel: %v", err) + } + }() + hh.setupMiddlewares(hw) if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil { diff --git a/internal/transport/https.go b/internal/transport/https.go index 104aa15..88ffe27 100644 --- a/internal/transport/https.go +++ b/internal/transport/https.go @@ -9,28 +9,25 @@ import ( ) type https struct { + tlsConfig *tls.Config httpHandler *httpHandler domain string port string } -func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { +func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport { return &https{ - httpHandler: newHTTPHandler(sessionRegistry, redirectTLS), + tlsConfig: tlsConfig, + httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS), domain: domain, port: port, } } func (ht *https) Listen() (net.Listener, error) { - tlsConfig, err := NewTLSConfig(ht.domain) - if err != nil { - return nil, err - } - - return tls.Listen("tcp", ":"+ht.port, tlsConfig) - + return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig) } + func (ht *https) Serve(listener net.Listener) error { log.Printf("HTTPS server is starting on port %s", ht.port) for { diff --git a/internal/transport/tls.go b/internal/transport/tls.go index 0893b85..6824a54 100644 --- a/internal/transport/tls.go +++ b/internal/transport/tls.go @@ -26,7 +26,8 @@ type TLSManager interface { } type tlsManager struct { - domain string + config config.Config + certPath string keyPath string storagePath string @@ -42,7 +43,7 @@ type tlsManager struct { var globalTLSManager TLSManager var tlsManagerOnce sync.Once -func NewTLSConfig(domain string) (*tls.Config, error) { +func NewTLSConfig(config config.Config) (*tls.Config, error) { var initErr error tlsManagerOnce.Do(func() { @@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) { storagePath := "certs/tls/certmagic" tm := &tlsManager{ - domain: domain, + config: config, certPath: certPath, keyPath: keyPath, storagePath: storagePath, @@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) { tm.useCertMagic = false tm.startCertWatcher() } else { - if !isACMEConfigComplete() { - log.Printf("User certificates missing or invalid, and ACME configuration is incomplete") - log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable") - initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)") - return - } - - log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain) + log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain, config.Domain) if err := tm.initCertMagic(); err != nil { initErr = fmt.Errorf("failed to initialize CertMagic: %w", err) return @@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) { return globalTLSManager.getTLSConfig(), nil } -func isACMEConfigComplete() bool { - cfAPIToken := config.Getenv("CF_API_TOKEN", "") - return cfAPIToken != "" -} - func (tm *tlsManager) userCertsExistAndValid() bool { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { log.Printf("Certificate file not found: %s", tm.certPath) @@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool { return false } - return ValidateCertDomains(tm.certPath, tm.domain) + return ValidateCertDomains(tm.certPath, tm.config.Domain()) } func ValidateCertDomains(certPath, domain string) bool { @@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() { if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) { log.Printf("Certificate files changed, reloading...") - if !ValidateCertDomains(tm.certPath, tm.domain) { + if !ValidateCertDomains(tm.certPath, tm.config.Domain()) { log.Printf("New certificates don't cover required domains") - if !isACMEConfigComplete() { - log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)") - continue - } - - log.Printf("Switching to CertMagic for automatic certificate management") if err := tm.initCertMagic(); err != nil { log.Printf("Failed to initialize CertMagic: %v", err) continue @@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error { return fmt.Errorf("failed to create cert storage directory: %w", err) } - acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain) - cfAPIToken := config.Getenv("CF_API_TOKEN", "") - acmeStaging := config.Getenv("ACME_STAGING", "false") == "true" - - if cfAPIToken == "" { + if tm.config.CFAPIToken() == "" { return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") } cfProvider := &cloudflare.Provider{ - APIToken: cfAPIToken, + APIToken: tm.config.CFAPIToken(), } storage := &certmagic.FileStorage{Path: tm.storagePath} @@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error { }) acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ - Email: acmeEmail, + Email: tm.config.ACMEEmail(), Agreed: true, DNS01Solver: &certmagic.DNS01Solver{ DNSManager: certmagic.DNSManager{ @@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error { }, }) - if acmeStaging { + if tm.config.ACMEStaging() { acmeIssuer.CA = certmagic.LetsEncryptStagingCA log.Printf("Using Let's Encrypt staging server") } else { @@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error { magic.Issuers = []certmagic.Issuer{acmeIssuer} tm.magic = magic - domains := []string{tm.domain, "*." + tm.domain} + domains := []string{tm.config.Domain(), "*." + tm.config.Domain()} log.Printf("Requesting certificates for: %v", domains) ctx := context.Background() diff --git a/main.go b/main.go index 6510932..f897b46 100644 --- a/main.go +++ b/main.go @@ -9,8 +9,6 @@ import ( _ "net/http/pprof" "os" "os/signal" - "strconv" - "strings" "syscall" "time" "tunnel_pls/internal/config" @@ -21,6 +19,7 @@ import ( "tunnel_pls/internal/transport" "tunnel_pls/internal/version" "tunnel_pls/server" + "tunnel_pls/types" "golang.org/x/crypto/ssh" ) @@ -36,27 +35,12 @@ func main() { log.Printf("Starting %s", version.GetVersion()) - err := config.Load() + conf, err := config.MustLoad() if err != nil { log.Fatalf("Failed to load configuration: %s", err) return } - mode := strings.ToLower(config.Getenv("MODE", "standalone")) - isNodeMode := mode == "node" - - pprofEnabled := config.Getenv("PPROF_ENABLED", "false") - if pprofEnabled == "true" { - pprofPort := config.Getenv("PPROF_PORT", "6060") - go func() { - pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) - log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) - if err = http.ListenAndServe(pprofAddr, nil); err != nil { - log.Printf("pprof server error: %v", err) - } - }() - } - sshConfig := &ssh.ServerConfig{ NoClientAuth: true, ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), @@ -88,16 +72,11 @@ func main() { signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) var grpcClient client.Client - if isNodeMode { - grpcHost := config.Getenv("GRPC_ADDRESS", "localhost") - grpcPort := config.Getenv("GRPC_PORT", "8080") - grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort) - nodeToken := config.Getenv("NODE_TOKEN", "") - if nodeToken == "" { - log.Fatalf("NODE_TOKEN is required in node mode") - } - grpcClient, err = client.New(grpcAddr, sessionRegistry) + if conf.Mode() == types.ServerModeNODE { + grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort()) + + grpcClient, err = client.New(conf, grpcAddr, sessionRegistry) if err != nil { log.Fatalf("failed to create grpc client: %v", err) } @@ -110,46 +89,15 @@ func main() { healthCancel() go func() { - identity := config.Getenv("DOMAIN", "localhost") - if err = grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil { + if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { errChan <- fmt.Errorf("failed to subscribe to events: %w", err) } }() } - portManager := port.New() - rawRange := config.Getenv("ALLOWED_PORTS", "") - if rawRange != "" { - splitRange := strings.Split(rawRange, "-") - if len(splitRange) == 2 { - var start, end uint64 - start, err = strconv.ParseUint(splitRange[0], 10, 16) - if err != nil { - log.Fatalf("Failed to parse start port: %s", err) - } - - end, err = strconv.ParseUint(splitRange[1], 10, 16) - if err != nil { - log.Fatalf("Failed to parse end port: %s", err) - } - - if err = portManager.AddRange(uint16(start), uint16(end)); err != nil { - log.Fatalf("Failed to add port range: %s", err) - } - log.Printf("PortRegistry range configured: %d-%d", start, end) - } else { - log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange) - } - } - - tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true" - redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" - go func() { - httpPort := config.Getenv("HTTP_PORT", "8080") - var httpListener net.Listener - httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS) + httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect()) httpListener, err = httpserver.Listen() if err != nil { errChan <- fmt.Errorf("failed to start http server: %w", err) @@ -162,37 +110,52 @@ func main() { } }() - if tlsEnabled { + if conf.TLSEnabled() { go func() { - httpsPort := config.Getenv("HTTPS_PORT", "8443") - domain := config.Getenv("DOMAIN", "localhost") - - var httpListener net.Listener - httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS) - httpListener, err = httpserver.Listen() + var httpsListener net.Listener + tlsConfig, _ := transport.NewTLSConfig(conf) + httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig) + httpsListener, err = httpsServer.Listen() if err != nil { errChan <- fmt.Errorf("failed to start http server: %w", err) return } - err = httpserver.Serve(httpListener) + err = httpsServer.Serve(httpsListener) if err != nil { errChan <- fmt.Errorf("error when serving http server: %w", err) return } }() } + + portManager := port.New() + err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()) + if err != nil { + log.Fatalf("Failed to initialize port manager: %s", err) + return + } var app server.Server go func() { - sshPort := config.Getenv("PORT", "2200") - app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort) + app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort()) if err != nil { errChan <- fmt.Errorf("failed to start server: %s", err) return } app.Start() + }() + if conf.PprofEnabled() { + go func() { + pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort()) + log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) + if err = http.ListenAndServe(pprofAddr, nil); err != nil { + log.Printf("pprof server error: %v", err) + } + }() + } + select { case err = <-errChan: log.Printf("error happen : %s", err) diff --git a/server/server.go b/server/server.go index 185d051..f47c579 100644 --- a/server/server.go +++ b/server/server.go @@ -7,6 +7,7 @@ import ( "log" "net" "time" + "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/port" "tunnel_pls/internal/registry" @@ -20,24 +21,26 @@ type Server interface { Close() error } type server struct { + config config.Config sshPort string sshListener net.Listener - config *ssh.ServerConfig + sshConfig *ssh.ServerConfig grpcClient client.Client sessionRegistry registry.Registry portRegistry port.Port } -func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { +func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort)) if err != nil { return nil, err } return &server{ + config: config, sshPort: sshPort, sshListener: listener, - config: sshConfig, + sshConfig: sshConfig, grpcClient: grpcClient, sessionRegistry: sessionRegistry, portRegistry: portRegistry, @@ -66,7 +69,7 @@ func (s *server) Close() error { } func (s *server) handleConnection(conn net.Conn) { - sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) + sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig) if err != nil { log.Printf("failed to establish SSH connection: %v", err) err = conn.Close() @@ -92,7 +95,7 @@ func (s *server) handleConnection(conn net.Conn) { cancel() } log.Println("SSH connection established:", sshConn.User()) - sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) + sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) err = sshSession.Start() if err != nil { log.Printf("SSH session ended with error: %v", err) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index ff2abde..c602565 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -18,37 +18,6 @@ import ( "golang.org/x/crypto/ssh" ) -var bufferPool = sync.Pool{ - New: func() interface{} { - bufSize := config.GetBufferSize() - return make([]byte, bufSize) - }, -} - -func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { - buf := bufferPool.Get().([]byte) - defer bufferPool.Put(buf) - return io.CopyBuffer(dst, src, buf) -} - -type forwarder struct { - listener net.Listener - tunnelType types.TunnelType - forwardedPort uint16 - slug slug.Slug - conn ssh.Conn -} - -func New(slug slug.Slug, conn ssh.Conn) Forwarder { - return &forwarder{ - listener: nil, - tunnelType: types.UNKNOWN, - forwardedPort: 0, - slug: slug, - conn: conn, - } -} - type Forwarder interface { SetType(tunnelType types.TunnelType) SetForwardedPort(port uint16) @@ -62,6 +31,36 @@ type Forwarder interface { WriteBadGatewayResponse(dst io.Writer) Close() error } +type forwarder struct { + listener net.Listener + tunnelType types.TunnelType + forwardedPort uint16 + slug slug.Slug + conn ssh.Conn + bufferPool sync.Pool +} + +func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder { + return &forwarder{ + listener: nil, + tunnelType: types.TunnelTypeUNKNOWN, + forwardedPort: 0, + slug: slug, + conn: conn, + bufferPool: sync.Pool{ + New: func() interface{} { + bufSize := config.BufferSize() + return make([]byte, bufSize) + }, + }, + } +} + +func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { + buf := f.bufferPool.Get().([]byte) + defer f.bufferPool.Put(buf) + return io.CopyBuffer(dst, src, buf) +} func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { type channelResult struct { @@ -107,7 +106,7 @@ func closeWriter(w io.Writer) error { func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error { var errs []error - _, err := copyWithBuffer(dst, src) + _, err := f.copyWithBuffer(dst, src) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err)) } diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 3c02dae..5f68102 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -18,9 +18,9 @@ import ( ) type Interaction interface { - Mode() types.Mode + Mode() types.InteractiveMode SetChannel(channel ssh.Channel) - SetMode(m types.Mode) + SetMode(m types.InteractiveMode) SetWH(w, h int) Start() Redraw() @@ -39,6 +39,7 @@ type Forwarder interface { type CloseFunc func() error type interaction struct { + config config.Config channel ssh.Channel slug slug.Slug forwarder Forwarder @@ -48,14 +49,14 @@ type interaction struct { program *tea.Program ctx context.Context cancel context.CancelFunc - mode types.Mode + mode types.InteractiveMode } -func (i *interaction) SetMode(m types.Mode) { +func (i *interaction) SetMode(m types.InteractiveMode) { i.mode = m } -func (i *interaction) Mode() types.Mode { +func (i *interaction) Mode() types.InteractiveMode { return i.mode } @@ -75,9 +76,10 @@ func (i *interaction) SetWH(w, h int) { } } -func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { +func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { ctx, cancel := context.WithCancel(context.Background()) return &interaction{ + config: config, channel: nil, slug: slug, forwarder: forwarder, @@ -174,14 +176,13 @@ func (m *model) View() string { } func (i *interaction) Start() { - if i.mode == types.HEADLESS { + if i.mode == types.InteractiveModeHEADLESS { return } lipgloss.SetColorProfile(termenv.TrueColor) - domain := config.Getenv("DOMAIN", "localhost") protocol := "http" - if config.Getenv("TLS_ENABLED", "false") == "true" { + if i.config.TLSEnabled() { protocol = "https" } @@ -209,7 +210,7 @@ func (i *interaction) Start() { ti.Width = 50 m := &model{ - domain: domain, + domain: i.config.Domain(), protocol: protocol, tunnelType: tunnelType, port: port, diff --git a/session/interaction/model.go b/session/interaction/model.go index 24b4d26..189b0a1 100644 --- a/session/interaction/model.go +++ b/session/interaction/model.go @@ -41,7 +41,7 @@ type model struct { } func (m *model) getTunnelURL() string { - if m.tunnelType == types.HTTP { + if m.tunnelType == types.TunnelTypeHTTP { return buildURL(m.protocol, m.interaction.slug.String(), m.domain) } return fmt.Sprintf("tcp://%s:%d", m.domain, m.port) diff --git a/session/interaction/slug.go b/session/interaction/slug.go index 6c6a97b..08c7c7d 100644 --- a/session/interaction/slug.go +++ b/session/interaction/slug.go @@ -15,7 +15,7 @@ import ( func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { var cmd tea.Cmd - if m.tunnelType != types.HTTP { + if m.tunnelType != types.TunnelTypeHTTP { m.editingSlug = false m.slugError = "" return m, tea.Batch(tea.ClearScreen, textinput.Blink) @@ -30,10 +30,10 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { inputValue := m.slugInput.Value() if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{ Id: m.interaction.slug.String(), - Type: types.HTTP, + Type: types.TunnelTypeHTTP, }, types.SessionKey{ Id: inputValue, - Type: types.HTTP, + Type: types.TunnelTypeHTTP, }); err != nil { m.slugError = err.Error() return m, nil @@ -130,7 +130,7 @@ func (m *model) slugView() string { b.WriteString(titleStyle.Render(title)) b.WriteString("\n\n") - if m.tunnelType != types.HTTP { + if m.tunnelType != types.TunnelTypeHTTP { warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60) warningBoxStyle := lipgloss.NewStyle(). Foreground(lipgloss.Color("#FFA500")). @@ -145,9 +145,9 @@ func (m *model) slugView() string { var warningText string if isVeryCompact { - warningText = "⚠️ TCP tunnels don't support custom subdomains." + warningText = "⚠️ TunnelTypeTCP tunnels don't support custom subdomains." } else { - warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization." + warningText = "⚠️ TunnelTypeTCP tunnels cannot have custom subdomains. Only TunnelTypeHTTP/HTTPS tunnels support subdomain customization." } b.WriteString(warningBoxStyle.Render(warningText)) b.WriteString("\n\n") diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 7a2fcaf..e4ce44f 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -24,7 +24,7 @@ type SessionRegistry interface { } type lifecycle struct { - status types.Status + status types.SessionStatus conn ssh.Conn channel ssh.Channel forwarder Forwarder @@ -37,7 +37,7 @@ type lifecycle struct { func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle { return &lifecycle{ - status: types.INITIALIZING, + status: types.SessionStatusINITIALIZING, conn: conn, channel: nil, forwarder: forwarder, @@ -54,7 +54,7 @@ type Lifecycle interface { PortRegistry() portUtil.Port User() string SetChannel(channel ssh.Channel) - SetStatus(status types.Status) + SetStatus(status types.SessionStatus) IsActive() bool StartedAt() time.Time Close() error @@ -74,9 +74,9 @@ func (l *lifecycle) SetChannel(channel ssh.Channel) { func (l *lifecycle) Connection() ssh.Conn { return l.conn } -func (l *lifecycle) SetStatus(status types.Status) { +func (l *lifecycle) SetStatus(status types.SessionStatus) { l.status = status - if status == types.RUNNING && l.startedAt.IsZero() { + if status == types.SessionStatusRUNNING && l.startedAt.IsZero() { l.startedAt = time.Now() } } @@ -112,7 +112,7 @@ func (l *lifecycle) Close() error { } l.sessionRegistry.Remove(key) - if tunnelType == types.TCP { + if tunnelType == types.TunnelTypeTCP { if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { firstErr = err } @@ -122,7 +122,7 @@ func (l *lifecycle) Close() error { } func (l *lifecycle) IsActive() bool { - return l.status == types.RUNNING + return l.status == types.SessionStatusRUNNING } func (l *lifecycle) StartedAt() time.Time { diff --git a/session/session.go b/session/session.go index d113084..65bbc54 100644 --- a/session/session.go +++ b/session/session.go @@ -37,6 +37,7 @@ type Session interface { } type session struct { + config config.Config initialReq <-chan *ssh.Request sshChan <-chan ssh.NewChannel lifecycle lifecycle.Lifecycle @@ -48,13 +49,14 @@ type session struct { var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { +func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { slugManager := slug.New() - forwarderManager := forwarder.New(slugManager, conn) + forwarderManager := forwarder.New(config, slugManager, conn) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) - interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) + interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) return &session{ + config: config, initialReq: initialReq, sshChan: sshChan, lifecycle: lifecycleManager, @@ -83,12 +85,12 @@ func (s *session) Slug() slug.Slug { func (s *session) Detail() *types.Detail { tunnelTypeMap := map[types.TunnelType]string{ - types.HTTP: "HTTP", - types.TCP: "TCP", + types.TunnelTypeHTTP: "TunnelTypeHTTP", + types.TunnelTypeTCP: "TunnelTypeTCP", } tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()] if !ok { - tunnelType = "UNKNOWN" + tunnelType = "TunnelTypeUNKNOWN" } return &types.Detail{ @@ -131,7 +133,7 @@ func (s *session) setupSessionMode() error { } return s.setupInteractiveMode(channel) case <-time.After(500 * time.Millisecond): - s.interaction.SetMode(types.HEADLESS) + s.interaction.SetMode(types.InteractiveModeHEADLESS) return nil } } @@ -152,13 +154,13 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error { s.lifecycle.SetChannel(ch) s.interaction.SetChannel(ch) - s.interaction.SetMode(types.INTERACTIVE) + s.interaction.SetMode(types.InteractiveModeINTERACTIVE) return nil } func (s *session) handleMissingForwardRequest() error { - err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) + err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain, s.config.SSHPort)) if err != nil { return err } @@ -169,8 +171,8 @@ func (s *session) handleMissingForwardRequest() error { } func (s *session) shouldRejectUnauthorized() bool { - return s.interaction.Mode() == types.HEADLESS && - config.Getenv("MODE", "standalone") == "standalone" && + return s.interaction.Mode() == types.InteractiveModeHEADLESS && + s.config.Mode() == types.ServerModeSTANDALONE && s.lifecycle.User() == "UNAUTHORIZED" } @@ -318,7 +320,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen s.forwarder.SetType(tunnelType) s.forwarder.SetForwardedPort(portToBind) s.slug.Set(slug) - s.lifecycle.SetStatus(types.RUNNING) + s.lifecycle.SetStatus(types.SessionStatusRUNNING) if listener != nil { s.forwarder.SetListener(listener) @@ -348,12 +350,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { if err != nil { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err)) } - key := types.SessionKey{Id: randomString, Type: types.HTTP} + key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP} if !s.registry.Register(key, s) { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString)) } - err = s.finalizeForwarding(req, portToBind, nil, types.HTTP, key.Id) + err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id) if err != nil { return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } @@ -371,12 +373,12 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } - key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} + key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP} if !s.registry.Register(key, s) { - return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TCP client with id: %s", key.Id)) + return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id)) } - err = s.finalizeForwarding(req, portToBind, listener, types.TCP, key.Id) + err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id) if err != nil { return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } diff --git a/types/types.go b/types/types.go index b91dffb..8e7d1e5 100644 --- a/types/types.go +++ b/types/types.go @@ -2,26 +2,33 @@ package types import "time" -type Status int +type SessionStatus int const ( - INITIALIZING Status = iota - RUNNING + SessionStatusINITIALIZING SessionStatus = iota + SessionStatusRUNNING ) -type Mode int +type InteractiveMode int const ( - INTERACTIVE Mode = iota - HEADLESS + InteractiveModeINTERACTIVE InteractiveMode = iota + 1 + InteractiveModeHEADLESS ) type TunnelType int const ( - UNKNOWN TunnelType = iota - HTTP - TCP + TunnelTypeUNKNOWN TunnelType = iota + TunnelTypeHTTP + TunnelTypeTCP +) + +type ServerMode int + +const ( + ServerModeSTANDALONE = iota + 1 + ServerModeNODE ) type SessionKey struct { @@ -37,7 +44,7 @@ type Detail struct { StartedAt time.Time `json:"started_at,omitempty"` } -var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + +var BadGatewayResponse = []byte("TunnelTypeHTTP/1.1 502 Bad Gateway\r\n" + "Content-Length: 11\r\n" + "Content-Type: text/plain\r\n\r\n" + "Bad Gateway") -- 2.49.1 From 1408b80917980700e89a4da62a1dfcb6971d5459 Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 21 Jan 2026 21:05:10 +0700 Subject: [PATCH 09/11] ci: add sonarqube scan --- .gitea/workflows/sonarqube.yml | 20 ++++++++++++++++++++ sonar-project.properties | 1 + 2 files changed, 21 insertions(+) create mode 100644 .gitea/workflows/sonarqube.yml create mode 100644 sonar-project.properties diff --git a/.gitea/workflows/sonarqube.yml b/.gitea/workflows/sonarqube.yml new file mode 100644 index 0000000..9c672ac --- /dev/null +++ b/.gitea/workflows/sonarqube.yml @@ -0,0 +1,20 @@ +on: + push: + pull_request: + types: [opened, synchronize, reopened] + +name: SonarQube Scan +jobs: + sonarqube: + name: SonarQube Trigger + runs-on: ubuntu-latest + steps: + - name: Checking out + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: SonarQube Scan + uses: SonarSource/sonarqube-scan-action@v7.0.0 + env: + SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }} + SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }} diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..277a293 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1 @@ +sonar.projectKey=tunnel-please \ No newline at end of file -- 2.49.1 From 9f4c24a3f37a47933db0699eea1b3ce4f2af6277 Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 21 Jan 2026 21:55:38 +0700 Subject: [PATCH 10/11] refactor(lifecycle): reorder resource closing and simplify Close() - Close channel and connection first, then remove session - Close forwarded port and forwarder at the end for TCP tunnels - Aggregate all errors using errors.Join instead of failing early --- session/lifecycle/lifecycle.go | 40 +++++++++++++++------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index e4ce44f..f9f9d6e 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -2,8 +2,6 @@ package lifecycle import ( "errors" - "io" - "net" "time" portUtil "tunnel_pls/internal/port" @@ -81,28 +79,23 @@ func (l *lifecycle) SetStatus(status types.SessionStatus) { } } +func closeIfNotNil(c interface{ Close() error }) error { + if c != nil { + return c.Close() + } + return nil +} + func (l *lifecycle) Close() error { - var firstErr error + var errs []error tunnelType := l.forwarder.TunnelType() - if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - firstErr = err + if err := closeIfNotNil(l.channel); err != nil { + errs = append(errs, err) } - if l.channel != nil { - if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - if firstErr == nil { - firstErr = err - } - } - } - - if l.conn != nil { - if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - if firstErr == nil { - firstErr = err - } - } + if err := closeIfNotNil(l.conn); err != nil { + errs = append(errs, err) } clientSlug := l.slug.String() @@ -113,12 +106,15 @@ func (l *lifecycle) Close() error { l.sessionRegistry.Remove(key) if tunnelType == types.TunnelTypeTCP { - if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { - firstErr = err + if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil { + errs = append(errs, err) + } + if err := l.forwarder.Close(); err != nil { + errs = append(errs, err) } } - return firstErr + return errors.Join(errs...) } func (l *lifecycle) IsActive() bool { -- 2.49.1 From 634c8321efae97c9bff8ebba2db82faee0210741 Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 21 Jan 2026 22:11:24 +0700 Subject: [PATCH 11/11] refactor(registry): define reusable constant errors - Introduced package-level error variables in registry to replace repeated fmt.Errorf calls - Added errors like ErrSessionNotFound, ErrSlugInUse, ErrInvalidSlug, ErrForbiddenSlug, ErrSlugChangeNotAllowed, and ErrSlugUnchanged --- internal/registry/registry.go | 27 ++++++++++++++++++--------- internal/transport/httphandler.go | 8 ++++---- internal/transport/tls.go | 2 +- session/interaction/slug.go | 4 ++-- session/session.go | 2 +- types/types.go | 2 +- 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/internal/registry/registry.go b/internal/registry/registry.go index 86898b0..89cac48 100644 --- a/internal/registry/registry.go +++ b/internal/registry/registry.go @@ -34,6 +34,15 @@ type registry struct { slugIndex map[Key]string } +var ( + ErrSessionNotFound = fmt.Errorf("session not found") + ErrSlugInUse = fmt.Errorf("slug already in use") + ErrInvalidSlug = fmt.Errorf("invalid slug") + ErrForbiddenSlug = fmt.Errorf("forbidden slug") + ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type") + ErrSlugUnchanged = fmt.Errorf("slug is unchanged") +) + func NewRegistry() Registry { return ®istry{ byUser: make(map[string]map[Key]Session), @@ -47,12 +56,12 @@ func (r *registry) Get(key Key) (session Session, err error) { userID, ok := r.slugIndex[key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, ErrSessionNotFound } client, ok := r.byUser[userID][key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, ErrSessionNotFound } return client, nil } @@ -63,37 +72,37 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error client, ok := r.byUser[user][key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, ErrSessionNotFound } return client, nil } func (r *registry) Update(user string, oldKey, newKey Key) error { if oldKey.Type != newKey.Type { - return fmt.Errorf("tunnel type cannot change") + return ErrSlugUnchanged } if newKey.Type != types.TunnelTypeHTTP { - return fmt.Errorf("non http tunnel cannot change slug") + return ErrSlugChangeNotAllowed } if isForbiddenSlug(newKey.Id) { - return fmt.Errorf("this subdomain is reserved. Please choose a different one") + return ErrForbiddenSlug } if !isValidSlug(newKey.Id) { - return fmt.Errorf("invalid subdomain. Follow the rules") + return ErrInvalidSlug } r.mu.Lock() defer r.mu.Unlock() if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { - return fmt.Errorf("someone already uses this subdomain") + return ErrSlugInUse } client, ok := r.byUser[user][oldKey] if !ok { - return fmt.Errorf("session not found") + return ErrSessionNotFound } delete(r.byUser[user], oldKey) diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go index b6f128d..8bab4a0 100644 --- a/internal/transport/httphandler.go +++ b/internal/transport/httphandler.go @@ -34,7 +34,7 @@ func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTL } func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error { - _, err := conn.Write([]byte(fmt.Sprintf("TunnelTypeHTTP/1.1 %d Moved Permanently\r\n", status) + + _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + fmt.Sprintf("Location: %s", location) + "Content-Length: 0\r\n" + "Connection: close\r\n" + @@ -46,7 +46,7 @@ func (hh *httpHandler) redirect(conn net.Conn, status int, location string) erro } func (hh *httpHandler) badRequest(conn net.Conn) error { - if _, err := conn.Write([]byte("TunnelTypeHTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { + if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { return err } return nil @@ -87,7 +87,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { defer func(hw stream.HTTP) { err = hw.Close() if err != nil { - log.Printf("Error closing TunnelTypeHTTP stream: %v", err) + log.Printf("Error closing HTTP stream: %v", err) } }(hw) hh.forwardRequest(hw, reqhf, sshSession) @@ -118,7 +118,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { } _, err := conn.Write([]byte( - "TunnelTypeHTTP/1.1 200 OK\r\n" + + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Connection: close\r\n" + "Access-Control-Allow-Origin: *\r\n" + diff --git a/internal/transport/tls.go b/internal/transport/tls.go index 6824a54..877afb4 100644 --- a/internal/transport/tls.go +++ b/internal/transport/tls.go @@ -67,7 +67,7 @@ func NewTLSConfig(config config.Config) (*tls.Config, error) { tm.useCertMagic = false tm.startCertWatcher() } else { - log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain, config.Domain) + log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain()) if err := tm.initCertMagic(); err != nil { initErr = fmt.Errorf("failed to initialize CertMagic: %w", err) return diff --git a/session/interaction/slug.go b/session/interaction/slug.go index 08c7c7d..2b871d4 100644 --- a/session/interaction/slug.go +++ b/session/interaction/slug.go @@ -145,9 +145,9 @@ func (m *model) slugView() string { var warningText string if isVeryCompact { - warningText = "⚠️ TunnelTypeTCP tunnels don't support custom subdomains." + warningText = "⚠️ TCP tunnels don't support custom subdomains." } else { - warningText = "⚠️ TunnelTypeTCP tunnels cannot have custom subdomains. Only TunnelTypeHTTP/HTTPS tunnels support subdomain customization." + warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization." } b.WriteString(warningBoxStyle.Render(warningText)) b.WriteString("\n\n") diff --git a/session/session.go b/session/session.go index 65bbc54..b1895ab 100644 --- a/session/session.go +++ b/session/session.go @@ -160,7 +160,7 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error { } func (s *session) handleMissingForwardRequest() error { - err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain, s.config.SSHPort)) + err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort())) if err != nil { return err } diff --git a/types/types.go b/types/types.go index 8e7d1e5..34ccfb4 100644 --- a/types/types.go +++ b/types/types.go @@ -44,7 +44,7 @@ type Detail struct { StartedAt time.Time `json:"started_at,omitempty"` } -var BadGatewayResponse = []byte("TunnelTypeHTTP/1.1 502 Bad Gateway\r\n" + +var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + "Content-Length: 11\r\n" + "Content-Type: text/plain\r\n\r\n" + "Bad Gateway") -- 2.49.1