diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index ffcc7d9..0874afe 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -8,10 +8,9 @@ import ( "log" "time" "tunnel_pls/internal/config" + "tunnel_pls/internal/registry" "tunnel_pls/types" - "tunnel_pls/session" - proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -32,13 +31,13 @@ type Client interface { type client struct { conn *grpc.ClientConn address string - sessionRegistry session.Registry + sessionRegistry registry.Registry eventService proto.EventServiceClient authorizeConnectionService proto.UserServiceClient closing bool } -func New(address string, sessionRegistry session.Registry) (Client, error) { +func New(address string, sessionRegistry registry.Registry) (Client, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) diff --git a/internal/httpheader/header.go b/internal/http/header/header.go similarity index 96% rename from internal/httpheader/header.go rename to internal/http/header/header.go index ccd1bed..a5e52b3 100644 --- a/internal/httpheader/header.go +++ b/internal/http/header/header.go @@ -1,4 +1,4 @@ -package httpheader +package header type ResponseHeader interface { Value(key string) string diff --git a/internal/httpheader/parser.go b/internal/http/header/parser.go similarity index 99% rename from internal/httpheader/parser.go rename to internal/http/header/parser.go index 3325ae5..861c49e 100644 --- a/internal/httpheader/parser.go +++ b/internal/http/header/parser.go @@ -1,4 +1,4 @@ -package httpheader +package header import ( "bufio" diff --git a/internal/httpheader/request.go b/internal/http/header/request.go similarity index 90% rename from internal/httpheader/request.go rename to internal/http/header/request.go index ae63340..b05f699 100644 --- a/internal/httpheader/request.go +++ b/internal/http/header/request.go @@ -1,11 +1,11 @@ -package httpheader +package header import ( "bufio" "fmt" ) -func NewRequestHeader(r interface{}) (RequestHeader, error) { +func NewRequest(r interface{}) (RequestHeader, error) { switch v := r.(type) { case []byte: return parseHeadersFromBytes(v) diff --git a/internal/httpheader/response.go b/internal/http/header/response.go similarity index 89% rename from internal/httpheader/response.go rename to internal/http/header/response.go index 63ad352..b6305d4 100644 --- a/internal/httpheader/response.go +++ b/internal/http/header/response.go @@ -1,11 +1,11 @@ -package httpheader +package header import ( "bytes" "fmt" ) -func NewResponseHeader(headerData []byte) (ResponseHeader, error) { +func NewResponse(headerData []byte) (ResponseHeader, error) { header := &responseHeader{ startLine: nil, headers: make(map[string]string, 16), diff --git a/internal/http/stream/parser.go b/internal/http/stream/parser.go new file mode 100644 index 0000000..b1d8277 --- /dev/null +++ b/internal/http/stream/parser.go @@ -0,0 +1,29 @@ +package stream + +import "bytes" + +func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) { + headerByte := data[:delimiterIdx+len(DELIMITER)] + body := data[delimiterIdx+len(DELIMITER):] + return headerByte, body +} + +func isHTTPHeader(buf []byte) bool { + lines := bytes.Split(buf, []byte("\r\n")) + + startLine := string(lines[0]) + if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { + return false + } + + for _, line := range lines[1:] { + if len(line) == 0 { + break + } + colonIdx := bytes.IndexByte(line, ':') + if colonIdx <= 0 { + return false + } + } + return true +} diff --git a/internal/http/stream/reader.go b/internal/http/stream/reader.go new file mode 100644 index 0000000..f7c99d4 --- /dev/null +++ b/internal/http/stream/reader.go @@ -0,0 +1,50 @@ +package stream + +import ( + "bytes" + "tunnel_pls/internal/http/header" +) + +func (hs *http) Read(p []byte) (int, error) { + tmp := make([]byte, len(p)) + read, err := hs.reader.Read(tmp) + if read == 0 && err != nil { + return 0, err + } + + tmp = tmp[:read] + + headerEndIdx := bytes.Index(tmp, DELIMITER) + if headerEndIdx == -1 { + return handleNoDelimiter(p, tmp, err) + } + + headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx) + + if !isHTTPHeader(headerByte) { + copy(p, tmp) + return read, nil + } + + return hs.processHTTPRequest(p, headerByte, bodyByte) +} + +func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) { + reqhf, err := header.NewRequest(headerByte) + if err != nil { + return 0, err + } + + if err = hs.ApplyRequestMiddlewares(reqhf); err != nil { + return 0, err + } + + hs.reqHeader = reqhf + combined := append(reqhf.Finalize(), bodyByte...) + return copy(p, combined), nil +} + +func handleNoDelimiter(p, tmp []byte, err error) (int, error) { + copy(p, tmp) + return len(tmp), err +} diff --git a/internal/http/stream/stream.go b/internal/http/stream/stream.go new file mode 100644 index 0000000..97d2752 --- /dev/null +++ b/internal/http/stream/stream.go @@ -0,0 +1,103 @@ +package stream + +import ( + "io" + "log" + "net" + "regexp" + "tunnel_pls/internal/http/header" + "tunnel_pls/internal/middleware" +) + +var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} +var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) +var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) + +type HTTP interface { + io.ReadWriteCloser + CloseWrite() error + RemoteAddr() net.Addr + UseResponseMiddleware(mw middleware.ResponseMiddleware) + UseRequestMiddleware(mw middleware.RequestMiddleware) + SetRequestHeader(header header.RequestHeader) + RequestMiddlewares() []middleware.RequestMiddleware + ResponseMiddlewares() []middleware.ResponseMiddleware + ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error + ApplyRequestMiddlewares(reqhf header.RequestHeader) error +} + +type http struct { + remoteAddr net.Addr + writer io.Writer + reader io.Reader + headerBuf []byte + buf []byte + respHeader header.ResponseHeader + reqHeader header.RequestHeader + respMW []middleware.ResponseMiddleware + reqMW []middleware.RequestMiddleware +} + +func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP { + return &http{ + remoteAddr: remoteAddr, + writer: writer, + reader: reader, + buf: make([]byte, 0, 4096), + } +} + +func (hs *http) RemoteAddr() net.Addr { + return hs.remoteAddr +} + +func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) { + hs.respMW = append(hs.respMW, mw) +} + +func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) { + hs.reqMW = append(hs.reqMW, mw) +} + +func (hs *http) SetRequestHeader(header header.RequestHeader) { + hs.reqHeader = header +} + +func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware { + return hs.reqMW +} + +func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware { + return hs.respMW +} + +func (hs *http) Close() error { + return hs.writer.(io.Closer).Close() +} + +func (hs *http) CloseWrite() error { + if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok { + return closer.CloseWrite() + } + return hs.Close() +} + +func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error { + for _, m := range hs.RequestMiddlewares() { + if err := m.HandleRequest(reqhf); err != nil { + log.Printf("Error when applying request middleware: %v", err) + return err + } + } + return nil +} + +func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error { + for _, m := range hs.ResponseMiddlewares() { + if err := m.HandleResponse(resphf, bodyByte); err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return err + } + } + return nil +} diff --git a/internal/http/stream/writer.go b/internal/http/stream/writer.go new file mode 100644 index 0000000..05e438b --- /dev/null +++ b/internal/http/stream/writer.go @@ -0,0 +1,88 @@ +package stream + +import ( + "bytes" + "tunnel_pls/internal/http/header" +) + +func (hs *http) Write(p []byte) (int, error) { + if hs.shouldBypassBuffering(p) { + hs.respHeader = nil + } + + if hs.respHeader != nil { + return hs.writer.Write(p) + } + + hs.buf = append(hs.buf, p...) + + headerEndIdx := bytes.Index(hs.buf, DELIMITER) + if headerEndIdx == -1 { + return len(p), nil + } + + return hs.processBufferedResponse(p, headerEndIdx) +} + +func (hs *http) shouldBypassBuffering(p []byte) bool { + return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" +} + +func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) { + headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx) + + if !isHTTPHeader(headerByte) { + return hs.writeRawBuffer() + } + + if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil { + return 0, err + } + + hs.buf = nil + return len(p), nil +} + +func (hs *http) writeRawBuffer() (int, error) { + _, err := hs.writer.Write(hs.buf) + length := len(hs.buf) + hs.buf = nil + if err != nil { + return 0, err + } + return length, nil +} + +func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error { + resphf, err := header.NewResponse(headerByte) + if err != nil { + return err + } + + if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil { + return err + } + + hs.respHeader = resphf + finalHeader := resphf.Finalize() + + if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil { + return err + } + + return nil +} + +func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error { + if _, err := hs.writer.Write(header); err != nil { + return err + } + + if len(bodyByte) > 0 { + if _, err := hs.writer.Write(bodyByte); err != nil { + return err + } + } + + return nil +} diff --git a/internal/middleware/forwardedfor.go b/internal/middleware/forwardedfor.go new file mode 100644 index 0000000..6a744a6 --- /dev/null +++ b/internal/middleware/forwardedfor.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "net" + "tunnel_pls/internal/http/header" +) + +type ForwardedFor struct { + addr net.Addr +} + +func NewForwardedFor(addr net.Addr) *ForwardedFor { + return &ForwardedFor{addr: addr} +} + +func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error { + host, _, err := net.SplitHostPort(ff.addr.String()) + if err != nil { + return err + } + header.Set("X-Forwarded-For", host) + return nil +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go new file mode 100644 index 0000000..4e3b163 --- /dev/null +++ b/internal/middleware/middleware.go @@ -0,0 +1,13 @@ +package middleware + +import ( + "tunnel_pls/internal/http/header" +) + +type RequestMiddleware interface { + HandleRequest(header header.RequestHeader) error +} + +type ResponseMiddleware interface { + HandleResponse(header header.ResponseHeader, body []byte) error +} diff --git a/internal/middleware/tunnelfingerprint.go b/internal/middleware/tunnelfingerprint.go new file mode 100644 index 0000000..68171c2 --- /dev/null +++ b/internal/middleware/tunnelfingerprint.go @@ -0,0 +1,16 @@ +package middleware + +import ( + "tunnel_pls/internal/http/header" +) + +type TunnelFingerprint struct{} + +func NewTunnelFingerprint() *TunnelFingerprint { + return &TunnelFingerprint{} +} + +func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error { + header.Set("Server", "Tunnel Please") + return nil +} diff --git a/internal/port/port.go b/internal/port/port.go index 01ecf96..6c60fbb 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -6,37 +6,37 @@ import ( "sync" ) -type Registry interface { - AddPortRange(startPort, endPort uint16) error - GetUnassignedPort() (uint16, bool) - SetPortStatus(port uint16, assigned bool) error - ClaimPort(port uint16) (claimed bool) +type Port interface { + AddRange(startPort, endPort uint16) error + Unassigned() (uint16, bool) + SetStatus(port uint16, assigned bool) error + Claim(port uint16) (claimed bool) } -type registry struct { +type port struct { mu sync.RWMutex ports map[uint16]bool sortedPorts []uint16 } -func New() Registry { - return ®istry{ +func New() Port { + return &port{ ports: make(map[uint16]bool), sortedPorts: []uint16{}, } } -func (pm *registry) AddPortRange(startPort, endPort uint16) error { +func (pm *port) AddRange(startPort, endPort uint16) error { pm.mu.Lock() defer pm.mu.Unlock() if startPort > endPort { return fmt.Errorf("start port cannot be greater than end port") } - for port := startPort; port <= endPort; port++ { - if _, exists := pm.ports[port]; !exists { - pm.ports[port] = false - pm.sortedPorts = append(pm.sortedPorts, port) + for index := startPort; index <= endPort; index++ { + if _, exists := pm.ports[index]; !exists { + pm.ports[index] = false + pm.sortedPorts = append(pm.sortedPorts, index) } } sort.Slice(pm.sortedPorts, func(i, j int) bool { @@ -45,19 +45,19 @@ func (pm *registry) AddPortRange(startPort, endPort uint16) error { return nil } -func (pm *registry) GetUnassignedPort() (uint16, bool) { +func (pm *port) Unassigned() (uint16, bool) { pm.mu.Lock() defer pm.mu.Unlock() - for _, port := range pm.sortedPorts { - if !pm.ports[port] { - return port, true + for _, index := range pm.sortedPorts { + if !pm.ports[index] { + return index, true } } return 0, false } -func (pm *registry) SetPortStatus(port uint16, assigned bool) error { +func (pm *port) SetStatus(port uint16, assigned bool) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -65,7 +65,7 @@ func (pm *registry) SetPortStatus(port uint16, assigned bool) error { return nil } -func (pm *registry) ClaimPort(port uint16) (claimed bool) { +func (pm *port) Claim(port uint16) (claimed bool) { pm.mu.Lock() defer pm.mu.Unlock() diff --git a/session/registry.go b/internal/registry/registry.go similarity index 90% rename from session/registry.go rename to internal/registry/registry.go index 6698cf1..22e590a 100644 --- a/session/registry.go +++ b/internal/registry/registry.go @@ -1,13 +1,25 @@ -package session +package registry import ( "fmt" "sync" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" "tunnel_pls/types" ) type Key = types.SessionKey +type Session interface { + Lifecycle() lifecycle.Lifecycle + Interaction() interaction.Interaction + Forwarder() forwarder.Forwarder + Slug() slug.Slug + Detail() *types.Detail +} + type Registry interface { Get(key Key) (session Session, err error) GetWithUser(user string, key Key) (session Session, err error) @@ -35,12 +47,12 @@ func (r *registry) Get(key Key) (session Session, err error) { userID, ok := r.slugIndex[key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, fmt.Errorf("Session not found") } client, ok := r.byUser[userID][key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, fmt.Errorf("Session not found") } return client, nil } @@ -51,7 +63,7 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error client, ok := r.byUser[user][key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, fmt.Errorf("Session not found") } return client, nil } @@ -81,7 +93,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { } client, ok := r.byUser[user][oldKey] if !ok { - return fmt.Errorf("session not found") + return fmt.Errorf("Session not found") } delete(r.byUser[user], oldKey) @@ -97,7 +109,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { return nil } -func (r *registry) Register(key Key, session Session) (success bool) { +func (r *registry) Register(key Key, userSession Session) (success bool) { r.mu.Lock() defer r.mu.Unlock() @@ -105,12 +117,12 @@ func (r *registry) Register(key Key, session Session) (success bool) { return false } - userID := session.Lifecycle().User() + userID := userSession.Lifecycle().User() if r.byUser[userID] == nil { r.byUser[userID] = make(map[Key]Session) } - r.byUser[userID][key] = session + r.byUser[userID][key] = userSession r.slugIndex[key] = userID return true } diff --git a/internal/transport/http.go b/internal/transport/http.go new file mode 100644 index 0000000..bf698ab --- /dev/null +++ b/internal/transport/http.go @@ -0,0 +1,40 @@ +package transport + +import ( + "errors" + "log" + "net" + "tunnel_pls/internal/registry" +) + +type httpServer struct { + handler *httpHandler + port string +} + +func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { + return &httpServer{ + handler: newHTTPHandler(sessionRegistry, redirectTLS), + port: port, + } +} + +func (ht *httpServer) Listen() (net.Listener, error) { + return net.Listen("tcp", ":"+ht.port) +} + +func (ht *httpServer) Serve(listener net.Listener) error { + log.Printf("HTTP server is starting on port %s", ht.port) + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return err + } + log.Printf("Error accepting connection: %v", err) + continue + } + + go ht.handler.handler(conn, false) + } +} diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go new file mode 100644 index 0000000..0b22e48 --- /dev/null +++ b/internal/transport/httphandler.go @@ -0,0 +1,227 @@ +package transport + +import ( + "bufio" + "errors" + "fmt" + "log" + "net" + "net/http" + "strings" + "time" + "tunnel_pls/internal/config" + "tunnel_pls/internal/http/header" + "tunnel_pls/internal/http/stream" + "tunnel_pls/internal/middleware" + "tunnel_pls/internal/registry" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" +) + +type httpHandler struct { + sessionRegistry registry.Registry + redirectTLS bool +} + +func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { + return &httpHandler{ + sessionRegistry: sessionRegistry, + redirectTLS: redirectTLS, + } +} + +func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error { + _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + + fmt.Sprintf("Location: %s", location) + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "\r\n")) + if err != nil { + return err + } + return nil +} + +func (hh *httpHandler) badRequest(conn net.Conn) error { + if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { + return err + } + return nil +} + +func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { + defer hh.closeConnection(conn) + + dstReader := bufio.NewReader(conn) + reqhf, err := header.NewRequest(dstReader) + if err != nil { + log.Printf("Error creating request header: %v", err) + return + } + + slug, err := hh.extractSlug(reqhf) + if err != nil { + _ = hh.badRequest(conn) + return + } + + if hh.shouldRedirectToTLS(isTLS) { + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))) + return + } + + if hh.handlePingRequest(slug, conn) { + return + } + + sshSession, err := hh.getSession(slug) + if err != nil { + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) + return + } + + hw := stream.New(conn, dstReader, conn.RemoteAddr()) + defer func(hw stream.HTTP) { + err = hw.Close() + if err != nil { + log.Printf("Error closing HTTP stream: %v", err) + } + }(hw) + hh.forwardRequest(hw, reqhf, sshSession) +} + +func (hh *httpHandler) closeConnection(conn net.Conn) { + err := conn.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + log.Printf("Error closing connection: %v", err) + } +} + +func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) { + host := strings.Split(reqhf.Value("Host"), ".") + if len(host) < 1 { + return "", errors.New("invalid host") + } + return host[0], nil +} + +func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool { + return !isTLS && hh.redirectTLS +} + +func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { + if slug != "ping" { + return false + } + + _, err := conn.Write([]byte( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + + "Access-Control-Allow-Headers: *\r\n" + + "\r\n", + )) + if err != nil { + log.Println("Failed to write 200 OK:", err) + } + return true +} + +func (hh *httpHandler) getSession(slug string) (registry.Session, error) { + sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ + Id: slug, + Type: types.HTTP, + }) + if err != nil { + return nil, err + } + return sshSession, nil +} + +func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { + channel, err := hh.openForwardedChannel(hw, sshSession) + defer func() { + err = channel.Close() + if err != nil { + log.Printf("Error closing forwarded channel: %v", err) + } + }() + if err != nil { + log.Printf("Failed to establish channel: %v", err) + sshSession.Forwarder().WriteBadGatewayResponse(hw) + return + } + hh.setupMiddlewares(hw) + + if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil { + log.Printf("Failed to forward initial request: %v", err) + return + } + sshSession.Forwarder().HandleConnection(hw, channel) +} + +func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) { + payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) + + type channelResult struct { + channel ssh.Channel + reqs <-chan *ssh.Request + err error + } + + resultChan := make(chan channelResult, 1) + + go func() { + channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) + select { + case resultChan <- channelResult{channel, reqs, err}: + default: + hh.cleanupUnusedChannel(channel, reqs) + } + }() + + select { + case result := <-resultChan: + if result.err != nil { + return nil, result.err + } + go ssh.DiscardRequests(result.reqs) + return result.channel, nil + case <-time.After(5 * time.Second): + return nil, errors.New("timeout opening forwarded-tcpip channel") + } +} + +func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) { + if channel != nil { + if err := channel.Close(); err != nil { + log.Printf("Failed to close unused channel: %v", err) + } + go ssh.DiscardRequests(reqs) + } +} + +func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) { + fingerprintMiddleware := middleware.NewTunnelFingerprint() + forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr()) + + hw.UseResponseMiddleware(fingerprintMiddleware) + hw.UseRequestMiddleware(forwardedForMiddleware) +} + +func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error { + hw.SetRequestHeader(initialRequest) + + if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { + return fmt.Errorf("error applying request middlewares: %w", err) + } + + if _, err := channel.Write(initialRequest.Finalize()); err != nil { + return fmt.Errorf("error writing to channel: %w", err) + } + + return nil +} diff --git a/internal/transport/https.go b/internal/transport/https.go new file mode 100644 index 0000000..104aa15 --- /dev/null +++ b/internal/transport/https.go @@ -0,0 +1,48 @@ +package transport + +import ( + "crypto/tls" + "errors" + "log" + "net" + "tunnel_pls/internal/registry" +) + +type https struct { + httpHandler *httpHandler + domain string + port string +} + +func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { + return &https{ + httpHandler: newHTTPHandler(sessionRegistry, redirectTLS), + domain: domain, + port: port, + } +} + +func (ht *https) Listen() (net.Listener, error) { + tlsConfig, err := NewTLSConfig(ht.domain) + if err != nil { + return nil, err + } + + return tls.Listen("tcp", ":"+ht.port, tlsConfig) + +} +func (ht *https) Serve(listener net.Listener) error { + log.Printf("HTTPS server is starting on port %s", ht.port) + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return err + } + log.Printf("Error accepting connection: %v", err) + continue + } + + go ht.httpHandler.handler(conn, true) + } +} diff --git a/internal/transport/tcp.go b/internal/transport/tcp.go new file mode 100644 index 0000000..99670d2 --- /dev/null +++ b/internal/transport/tcp.go @@ -0,0 +1,66 @@ +package transport + +import ( + "errors" + "fmt" + "io" + "log" + "net" + + "golang.org/x/crypto/ssh" +) + +type tcp struct { + port uint16 + forwarder forwarder +} + +type forwarder interface { + CreateForwardedTCPIPPayload(origin net.Addr) []byte + OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) + HandleConnection(dst io.ReadWriter, src ssh.Channel) +} + +func NewTCPServer(port uint16, forwarder forwarder) Transport { + return &tcp{ + port: port, + forwarder: forwarder, + } +} + +func (tt *tcp) Listen() (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port)) +} + +func (tt *tcp) Serve(listener net.Listener) error { + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + log.Printf("Error accepting connection: %v", err) + continue + } + go tt.handleTcp(conn) + } +} + +func (tt *tcp) handleTcp(conn net.Conn) { + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Failed to close connection: %v", err) + } + }() + payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr()) + channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + + return + } + + go ssh.DiscardRequests(reqs) + tt.forwarder.HandleConnection(conn, channel) +} diff --git a/server/tls.go b/internal/transport/tls.go similarity index 99% rename from server/tls.go rename to internal/transport/tls.go index fc67733..0893b85 100644 --- a/server/tls.go +++ b/internal/transport/tls.go @@ -1,4 +1,4 @@ -package server +package transport import ( "context" diff --git a/internal/transport/transport.go b/internal/transport/transport.go new file mode 100644 index 0000000..ca27061 --- /dev/null +++ b/internal/transport/transport.go @@ -0,0 +1,10 @@ +package transport + +import ( + "net" +) + +type Transport interface { + Listen() (net.Listener, error) + Serve(listener net.Listener) error +} diff --git a/version/version.go b/internal/version/version.go similarity index 100% rename from version/version.go rename to internal/version/version.go diff --git a/main.go b/main.go index 2303718..6510932 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log" + "net" "net/http" _ "net/http/pprof" "os" @@ -16,9 +17,10 @@ import ( "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/key" "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" + "tunnel_pls/internal/transport" + "tunnel_pls/internal/version" "tunnel_pls/server" - "tunnel_pls/session" - "tunnel_pls/version" "golang.org/x/crypto/ssh" ) @@ -76,7 +78,7 @@ func main() { } sshConfig.AddHostKey(private) - sessionRegistry := session.NewRegistry() + sessionRegistry := registry.NewRegistry() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -131,7 +133,7 @@ func main() { log.Fatalf("Failed to parse end port: %s", err) } - if err = portManager.AddPortRange(uint16(start), uint16(end)); err != nil { + if err = portManager.AddRange(uint16(start), uint16(end)); err != nil { log.Fatalf("Failed to add port range: %s", err) } log.Printf("PortRegistry range configured: %d-%d", start, end) @@ -139,9 +141,51 @@ func main() { log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange) } } + + tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true" + redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" + + go func() { + httpPort := config.Getenv("HTTP_PORT", "8080") + + var httpListener net.Listener + httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS) + httpListener, err = httpserver.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to start http server: %w", err) + return + } + err = httpserver.Serve(httpListener) + if err != nil { + errChan <- fmt.Errorf("error when serving http server: %w", err) + return + } + }() + + if tlsEnabled { + go func() { + httpsPort := config.Getenv("HTTPS_PORT", "8443") + domain := config.Getenv("DOMAIN", "localhost") + + var httpListener net.Listener + httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS) + httpListener, err = httpserver.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to start http server: %w", err) + return + } + err = httpserver.Serve(httpListener) + if err != nil { + errChan <- fmt.Errorf("error when serving http server: %w", err) + return + } + }() + } + var app server.Server go func() { - app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager) + sshPort := config.Getenv("PORT", "2200") + app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort) if err != nil { errChan <- fmt.Errorf("failed to start server: %s", err) return diff --git a/server/http.go b/server/http.go deleted file mode 100644 index a36b6b1..0000000 --- a/server/http.go +++ /dev/null @@ -1,276 +0,0 @@ -package server - -import ( - "bufio" - "crypto/tls" - "errors" - "fmt" - "log" - "net" - "net/http" - "strings" - "time" - "tunnel_pls/internal/config" - "tunnel_pls/internal/httpheader" - "tunnel_pls/session" - "tunnel_pls/types" - - "golang.org/x/crypto/ssh" -) - -type HTTPServer interface { - ListenAndServe() error - ListenAndServeTLS() error -} -type httpServer struct { - sessionRegistry session.Registry - redirectTLS bool -} - -func NewHTTPServer(sessionRegistry session.Registry, redirectTLS bool) HTTPServer { - return &httpServer{ - sessionRegistry: sessionRegistry, - redirectTLS: redirectTLS, - } -} - -func (hs *httpServer) ListenAndServe() error { - httpPort := config.Getenv("HTTP_PORT", "8080") - listener, err := net.Listen("tcp", ":"+httpPort) - if err != nil { - return errors.New("Error listening: " + err.Error()) - } - go func() { - for { - var conn net.Conn - conn, err = listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go hs.handler(conn, false) - } - }() - return nil -} - -func (hs *httpServer) ListenAndServeTLS() error { - domain := config.Getenv("DOMAIN", "localhost") - httpsPort := config.Getenv("HTTPS_PORT", "8443") - - tlsConfig, err := NewTLSConfig(domain) - if err != nil { - return fmt.Errorf("failed to initialize TLS config: %w", err) - } - - ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig) - if err != nil { - return err - } - - go func() { - for { - var conn net.Conn - conn, err = ln.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - log.Println("https server closed") - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go hs.handler(conn, true) - } - }() - return nil -} - -func (hs *httpServer) redirect(conn net.Conn, status int, location string) error { - _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + - fmt.Sprintf("Location: %s", location) + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n")) - if err != nil { - return err - } - return nil -} - -func (hs *httpServer) badRequest(conn net.Conn) error { - if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { - return err - } - return nil -} - -func (hs *httpServer) handler(conn net.Conn, isTLS bool) { - defer hs.closeConnection(conn) - - dstReader := bufio.NewReader(conn) - reqhf, err := httpheader.NewRequestHeader(dstReader) - if err != nil { - log.Printf("Error creating request header: %v", err) - return - } - - slug, err := hs.extractSlug(reqhf) - if err != nil { - _ = hs.badRequest(conn) - return - } - - if hs.shouldRedirectToTLS(isTLS) { - _ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))) - return - } - - if hs.handlePingRequest(slug, conn) { - return - } - - sshSession, err := hs.getSession(slug) - if err != nil { - _ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) - return - } - - hw := NewHTTPWriter(conn, dstReader, conn.RemoteAddr()) - hs.forwardRequest(hw, reqhf, sshSession) -} - -func (hs *httpServer) closeConnection(conn net.Conn) { - err := conn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - log.Printf("Error closing connection: %v", err) - } -} - -func (hs *httpServer) extractSlug(reqhf httpheader.RequestHeader) (string, error) { - host := strings.Split(reqhf.Value("Host"), ".") - if len(host) < 1 { - return "", errors.New("invalid host") - } - return host[0], nil -} - -func (hs *httpServer) shouldRedirectToTLS(isTLS bool) bool { - return !isTLS && hs.redirectTLS -} - -func (hs *httpServer) handlePingRequest(slug string, conn net.Conn) bool { - if slug != "ping" { - return false - } - - _, err := conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "Access-Control-Allow-Origin: *\r\n" + - "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + - "Access-Control-Allow-Headers: *\r\n" + - "\r\n", - )) - if err != nil { - log.Println("Failed to write 200 OK:", err) - } - return true -} - -func (hs *httpServer) getSession(slug string) (session.Session, error) { - sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ - Id: slug, - Type: types.HTTP, - }) - if err != nil { - return nil, err - } - return sshSession, nil -} - -func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, sshSession session.Session) { - channel, err := hs.openForwardedChannel(hw, sshSession) - if err != nil { - log.Printf("Failed to establish channel: %v", err) - sshSession.Forwarder().WriteBadGatewayResponse(hw) - return - } - - hs.setupMiddlewares(hw) - - if err := hs.sendInitialRequest(hw, initialRequest, channel); err != nil { - log.Printf("Failed to forward initial request: %v", err) - return - } - - sshSession.Forwarder().HandleConnection(hw, channel, hw.RemoteAddr()) -} - -func (hs *httpServer) openForwardedChannel(hw HTTPWriter, sshSession session.Session) (ssh.Channel, error) { - payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) - - type channelResult struct { - channel ssh.Channel - reqs <-chan *ssh.Request - err error - } - - resultChan := make(chan channelResult, 1) - - go func() { - channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) - select { - case resultChan <- channelResult{channel, reqs, err}: - default: - hs.cleanupUnusedChannel(channel, reqs) - } - }() - - select { - case result := <-resultChan: - if result.err != nil { - return nil, result.err - } - go ssh.DiscardRequests(result.reqs) - return result.channel, nil - case <-time.After(5 * time.Second): - return nil, errors.New("timeout opening forwarded-tcpip channel") - } -} - -func (hs *httpServer) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) { - if channel != nil { - if err := channel.Close(); err != nil { - log.Printf("Failed to close unused channel: %v", err) - } - go ssh.DiscardRequests(reqs) - } -} - -func (hs *httpServer) setupMiddlewares(hw HTTPWriter) { - fingerprintMiddleware := NewTunnelFingerprint() - forwardedForMiddleware := NewForwardedFor(hw.RemoteAddr()) - - hw.UseResponseMiddleware(fingerprintMiddleware) - hw.UseRequestMiddleware(forwardedForMiddleware) -} - -func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, channel ssh.Channel) error { - hw.SetRequestHeader(initialRequest) - - if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { - return fmt.Errorf("error applying request middlewares: %w", err) - } - - if _, err := channel.Write(initialRequest.Finalize()); err != nil { - return fmt.Errorf("error writing to channel: %w", err) - } - - return nil -} diff --git a/server/httpwritter.go b/server/httpwritter.go deleted file mode 100644 index 9d52f24..0000000 --- a/server/httpwritter.go +++ /dev/null @@ -1,254 +0,0 @@ -package server - -import ( - "bytes" - "io" - "log" - "net" - "regexp" - "tunnel_pls/internal/httpheader" -) - -type HTTPWriter interface { - io.ReadWriteCloser - CloseWrite() error - RemoteAddr() net.Addr - UseResponseMiddleware(mw ResponseMiddleware) - UseRequestMiddleware(mw RequestMiddleware) - SetRequestHeader(header httpheader.RequestHeader) - RequestMiddlewares() []RequestMiddleware - ResponseMiddlewares() []ResponseMiddleware - ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error - ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error -} - -type httpWriter struct { - remoteAddr net.Addr - writer io.Writer - reader io.Reader - headerBuf []byte - buf []byte - respHeader httpheader.ResponseHeader - reqHeader httpheader.RequestHeader - respMW []ResponseMiddleware - reqMW []RequestMiddleware -} - -var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} -var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) -var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) - -func (hw *httpWriter) RemoteAddr() net.Addr { - return hw.remoteAddr -} - -func (hw *httpWriter) UseResponseMiddleware(mw ResponseMiddleware) { - hw.respMW = append(hw.respMW, mw) -} - -func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) { - hw.reqMW = append(hw.reqMW, mw) -} - -func (hw *httpWriter) SetRequestHeader(header httpheader.RequestHeader) { - hw.reqHeader = header -} - -func (hw *httpWriter) RequestMiddlewares() []RequestMiddleware { - return hw.reqMW -} - -func (hw *httpWriter) ResponseMiddlewares() []ResponseMiddleware { - return hw.respMW -} -func (hw *httpWriter) Close() error { - return hw.writer.(io.Closer).Close() -} - -func (hw *httpWriter) CloseWrite() error { - if closer, ok := hw.writer.(interface{ CloseWrite() error }); ok { - return closer.CloseWrite() - } - return hw.Close() -} - -func (hw *httpWriter) Read(p []byte) (int, error) { - tmp := make([]byte, len(p)) - read, err := hw.reader.Read(tmp) - if read == 0 && err != nil { - return 0, err - } - - tmp = tmp[:read] - - headerEndIdx := bytes.Index(tmp, DELIMITER) - if headerEndIdx == -1 { - return hw.handleNoDelimiter(p, tmp, err) - } - - header, body := hw.splitHeaderAndBody(tmp, headerEndIdx) - - if !isHTTPHeader(header) { - copy(p, tmp) - return read, nil - } - - return hw.processHTTPRequest(p, header, body) -} - -func (hw *httpWriter) handleNoDelimiter(p, tmp []byte, err error) (int, error) { - copy(p, tmp) - return len(tmp), err -} - -func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) { - header := data[:delimiterIdx+len(DELIMITER)] - body := data[delimiterIdx+len(DELIMITER):] - return header, body -} - -func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { - reqhf, err := httpheader.NewRequestHeader(header) - if err != nil { - return 0, err - } - - if err = hw.ApplyRequestMiddlewares(reqhf); err != nil { - return 0, err - } - - hw.reqHeader = reqhf - combined := append(reqhf.Finalize(), body...) - return copy(p, combined), nil -} - -func (hw *httpWriter) ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error { - for _, m := range hw.RequestMiddlewares() { - if err := m.HandleRequest(reqhf); err != nil { - log.Printf("Error when applying request middleware: %v", err) - return err - } - } - return nil -} - -func (hw *httpWriter) Write(p []byte) (int, error) { - if hw.shouldBypassBuffering(p) { - hw.respHeader = nil - } - - if hw.respHeader != nil { - return hw.writer.Write(p) - } - - hw.buf = append(hw.buf, p...) - - headerEndIdx := bytes.Index(hw.buf, DELIMITER) - if headerEndIdx == -1 { - return len(p), nil - } - - return hw.processBufferedResponse(p, headerEndIdx) -} - -func (hw *httpWriter) shouldBypassBuffering(p []byte) bool { - return hw.respHeader != nil && len(hw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" -} - -func (hw *httpWriter) processBufferedResponse(p []byte, delimiterIdx int) (int, error) { - header, body := hw.splitHeaderAndBody(hw.buf, delimiterIdx) - - if !isHTTPHeader(header) { - return hw.writeRawBuffer() - } - - if err := hw.processHTTPResponse(header, body); err != nil { - return 0, err - } - - hw.buf = nil - return len(p), nil -} - -func (hw *httpWriter) writeRawBuffer() (int, error) { - _, err := hw.writer.Write(hw.buf) - length := len(hw.buf) - hw.buf = nil - if err != nil { - return 0, err - } - return length, nil -} - -func (hw *httpWriter) processHTTPResponse(header, body []byte) error { - resphf, err := httpheader.NewResponseHeader(header) - if err != nil { - return err - } - - if err = hw.ApplyResponseMiddlewares(resphf, body); err != nil { - return err - } - - hw.respHeader = resphf - finalHeader := resphf.Finalize() - - if err = hw.writeHeaderAndBody(finalHeader, body); err != nil { - return err - } - - return nil -} - -func (hw *httpWriter) ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error { - for _, m := range hw.ResponseMiddlewares() { - if err := m.HandleResponse(resphf, body); err != nil { - log.Printf("Cannot apply middleware: %s\n", err) - return err - } - } - return nil -} - -func (hw *httpWriter) writeHeaderAndBody(header, body []byte) error { - if _, err := hw.writer.Write(header); err != nil { - return err - } - - if len(body) > 0 { - if _, err := hw.writer.Write(body); err != nil { - return err - } - } - - return nil -} - -func NewHTTPWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter { - return &httpWriter{ - remoteAddr: remoteAddr, - writer: writer, - reader: reader, - buf: make([]byte, 0, 4096), - } -} - -func isHTTPHeader(buf []byte) bool { - lines := bytes.Split(buf, []byte("\r\n")) - - startLine := string(lines[0]) - if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { - return false - } - - for _, line := range lines[1:] { - if len(line) == 0 { - break - } - colonIdx := bytes.IndexByte(line, ':') - if colonIdx <= 0 { - return false - } - } - return true -} diff --git a/server/middleware.go b/server/middleware.go deleted file mode 100644 index 6f50c4c..0000000 --- a/server/middleware.go +++ /dev/null @@ -1,42 +0,0 @@ -package server - -import ( - "net" - "tunnel_pls/internal/httpheader" -) - -type RequestMiddleware interface { - HandleRequest(header httpheader.RequestHeader) error -} - -type ResponseMiddleware interface { - HandleResponse(header httpheader.ResponseHeader, body []byte) error -} - -type TunnelFingerprint struct{} - -func NewTunnelFingerprint() *TunnelFingerprint { - return &TunnelFingerprint{} -} - -func (h *TunnelFingerprint) HandleResponse(header httpheader.ResponseHeader, body []byte) error { - header.Set("Server", "Tunnel Please") - return nil -} - -type ForwardedFor struct { - addr net.Addr -} - -func NewForwardedFor(addr net.Addr) *ForwardedFor { - return &ForwardedFor{addr: addr} -} - -func (ff *ForwardedFor) HandleRequest(header httpheader.RequestHeader) error { - host, _, err := net.SplitHostPort(ff.addr.String()) - if err != nil { - return err - } - header.Set("X-Forwarded-For", host) - return nil -} diff --git a/server/server.go b/server/server.go index 868b9e6..185d051 100644 --- a/server/server.go +++ b/server/server.go @@ -7,9 +7,9 @@ import ( "log" "net" "time" - "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" "tunnel_pls/session" "golang.org/x/crypto/ssh" @@ -20,38 +20,23 @@ type Server interface { Close() error } type server struct { - listener net.Listener + sshPort string + sshListener net.Listener config *ssh.ServerConfig grpcClient client.Client - sessionRegistry session.Registry - portRegistry port.Registry + sessionRegistry registry.Registry + portRegistry port.Port } -func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) { - listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) +func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { + listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort)) if err != nil { - log.Fatalf("failed to listen on port 2200: %v", err) return nil, err } - redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" - - HttpServer := NewHTTPServer(sessionRegistry, redirectTLS) - err = HttpServer.ListenAndServe() - if err != nil { - log.Fatalf("failed to start http server: %v", err) - return nil, err - } - - if config.Getenv("TLS_ENABLED", "false") == "true" { - err = HttpServer.ListenAndServeTLS() - if err != nil { - log.Fatalf("failed to start https server: %v", err) - return nil, err - } - } return &server{ - listener: listener, + sshPort: sshPort, + sshListener: listener, config: sshConfig, grpcClient: grpcClient, sessionRegistry: sessionRegistry, @@ -60,9 +45,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie } func (s *server) Start() { - log.Println("SSH server is starting on port 2200...") + log.Printf("SSH server is starting on port %s", s.sshPort) for { - conn, err := s.listener.Accept() + conn, err := s.sshListener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { log.Println("listener closed, stopping server") @@ -77,7 +62,7 @@ func (s *server) Start() { } func (s *server) Close() error { - return s.listener.Close() + return s.sshListener.Close() } func (s *server) handleConnection(conn net.Conn) { diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index cac6691..ff2abde 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -56,14 +56,14 @@ type Forwarder interface { Listener() net.Listener TunnelType() types.TunnelType ForwardedPort() uint16 - HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) + HandleConnection(dst io.ReadWriter, src ssh.Channel) CreateForwardedTCPIPPayload(origin net.Addr) []byte + OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) WriteBadGatewayResponse(dst io.Writer) - AcceptTCPConnections() Close() error } -func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { +func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { type channelResult struct { channel ssh.Channel reqs <-chan *ssh.Request @@ -95,38 +95,6 @@ func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *s } } -func (f *forwarder) handleIncomingConnection(conn net.Conn) { - payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) - - channel, reqs, err := f.openForwardedChannel(payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - err = conn.Close() - if err != nil { - log.Printf("Failed to close connection: %v", err) - } - return - } - - go ssh.DiscardRequests(reqs) - go f.HandleConnection(conn, channel, conn.RemoteAddr()) -} - -func (f *forwarder) AcceptTCPConnections() { - for { - conn, err := f.Listener().Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go f.handleIncomingConnection(conn) - } -} - func closeWriter(w io.Writer) error { if cw, ok := w.(interface{ CloseWrite() error }); ok { return cw.CloseWrite() @@ -145,12 +113,12 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) } if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) { - errs = append(errs, fmt.Errorf("close writer error (%s): %w", direction, err)) + errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err)) } return errors.Join(errs...) } -func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { +func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { defer func() { _, err := io.Copy(io.Discard, src) if err != nil { @@ -158,8 +126,6 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA } }() - log.Printf("Handling new forwarded connection from %s", remoteAddr) - var wg sync.WaitGroup wg.Add(2) diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 8d134a2..7a2fcaf 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -31,11 +31,11 @@ type lifecycle struct { slug slug.Slug startedAt time.Time sessionRegistry SessionRegistry - portRegistry portUtil.Registry + portRegistry portUtil.Port user string } -func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle { +func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle { return &lifecycle{ status: types.INITIALIZING, conn: conn, @@ -51,7 +51,7 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti type Lifecycle interface { Connection() ssh.Conn - PortRegistry() portUtil.Registry + PortRegistry() portUtil.Port User() string SetChannel(channel ssh.Channel) SetStatus(status types.Status) @@ -60,7 +60,7 @@ type Lifecycle interface { Close() error } -func (l *lifecycle) PortRegistry() portUtil.Registry { +func (l *lifecycle) PortRegistry() portUtil.Port { return l.portRegistry } @@ -113,7 +113,7 @@ func (l *lifecycle) Close() error { l.sessionRegistry.Remove(key) if tunnelType == types.TCP { - if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { + if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { firstErr = err } } diff --git a/session/session.go b/session/session.go index 4118550..d113084 100644 --- a/session/session.go +++ b/session/session.go @@ -12,6 +12,8 @@ import ( "tunnel_pls/internal/config" portUtil "tunnel_pls/internal/port" "tunnel_pls/internal/random" + "tunnel_pls/internal/registry" + "tunnel_pls/internal/transport" "tunnel_pls/session/forwarder" "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" @@ -21,14 +23,6 @@ import ( "golang.org/x/crypto/ssh" ) -type Detail struct { - ForwardingType string `json:"forwarding_type,omitempty"` - Slug string `json:"slug,omitempty"` - UserID string `json:"user_id,omitempty"` - Active bool `json:"active,omitempty"` - StartedAt time.Time `json:"started_at,omitempty"` -} - type Session interface { HandleGlobalRequest(ch <-chan *ssh.Request) error HandleTCPIPForward(req *ssh.Request) error @@ -38,7 +32,7 @@ type Session interface { Interaction() interaction.Interaction Forwarder() forwarder.Forwarder Slug() slug.Slug - Detail() *Detail + Detail() *types.Detail Start() error } @@ -49,12 +43,12 @@ type session struct { interaction interaction.Interaction forwarder forwarder.Forwarder slug slug.Slug - registry Registry + registry registry.Registry } var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session { +func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { slugManager := slug.New() forwarderManager := forwarder.New(slugManager, conn) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) @@ -87,7 +81,7 @@ func (s *session) Slug() slug.Slug { return s.slug } -func (s *session) Detail() *Detail { +func (s *session) Detail() *types.Detail { tunnelTypeMap := map[types.TunnelType]string{ types.HTTP: "HTTP", types.TCP: "TCP", @@ -97,7 +91,7 @@ func (s *session) Detail() *Detail { tunnelType = "UNKNOWN" } - return &Detail{ + return &types.Detail{ ForwardingType: tunnelType, Slug: s.slug.String(), UserID: s.lifecycle.User(), @@ -271,7 +265,7 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, } if port == 0 { - unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort() + unassigned, ok := s.lifecycle.PortRegistry().Unassigned() if !ok { return "", 0, fmt.Errorf("no available port") } @@ -328,7 +322,6 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen if listener != nil { s.forwarder.SetListener(listener) - go s.forwarder.AcceptTCPConnections() } return nil @@ -346,7 +339,6 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) error { case 80, 443: return s.HandleHTTPForward(req, port) default: - return s.HandleTCPForward(req, address, port) } } @@ -369,11 +361,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { } func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error { - if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed { + if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } - listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) + tcpServer := transport.NewTCPServer(portToBind, s.forwarder) + listener, err := tcpServer.Listen() if err != nil { return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } @@ -387,6 +380,14 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin if err != nil { return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } + + go func() { + err = tcpServer.Serve(listener) + if err != nil { + log.Printf("Failed serving tcp server: %s\n", err) + } + }() + return nil } diff --git a/types/types.go b/types/types.go index 148cd2b..b91dffb 100644 --- a/types/types.go +++ b/types/types.go @@ -1,5 +1,7 @@ package types +import "time" + type Status int const ( @@ -27,6 +29,14 @@ type SessionKey struct { Type TunnelType } +type Detail struct { + ForwardingType string `json:"forwarding_type,omitempty"` + Slug string `json:"slug,omitempty"` + UserID string `json:"user_id,omitempty"` + Active bool `json:"active,omitempty"` + StartedAt time.Time `json:"started_at,omitempty"` +} + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + "Content-Length: 11\r\n" + "Content-Type: text/plain\r\n\r\n" +