diff --git a/.gitea/workflows/sonarqube.yml b/.gitea/workflows/sonarqube.yml new file mode 100644 index 0000000..9c672ac --- /dev/null +++ b/.gitea/workflows/sonarqube.yml @@ -0,0 +1,20 @@ +on: + push: + pull_request: + types: [opened, synchronize, reopened] + +name: SonarQube Scan +jobs: + sonarqube: + name: SonarQube Trigger + runs-on: ubuntu-latest + steps: + - name: Checking out + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: SonarQube Scan + uses: SonarSource/sonarqube-scan-action@v7.0.0 + env: + SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }} + SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }} diff --git a/internal/config/config.go b/internal/config/config.go index 45f1cc5..62e1aca 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,33 +1,63 @@ package config -import ( - "os" - "strconv" +import "tunnel_pls/types" - "github.com/joho/godotenv" -) +type Config interface { + Domain() string + SSHPort() string -func Load() error { - if _, err := os.Stat(".env"); err == nil { - return godotenv.Load(".env") - } - return nil + HTTPPort() string + HTTPSPort() string + + TLSEnabled() bool + TLSRedirect() bool + + ACMEEmail() string + CFAPIToken() string + ACMEStaging() bool + + AllowedPortsStart() uint16 + AllowedPortsEnd() uint16 + + BufferSize() int + + PprofEnabled() bool + PprofPort() string + + Mode() types.ServerMode + GRPCAddress() string + GRPCPort() string + NodeToken() string } -func Getenv(key, defaultValue string) string { - val := os.Getenv(key) - if val == "" { - val = defaultValue +func MustLoad() (Config, error) { + if err := loadEnvFile(); err != nil { + return nil, err } - return val + cfg, err := parse() + if err != nil { + return nil, err + } + + return cfg, nil } -func GetBufferSize() int { - sizeStr := Getenv("BUFFER_SIZE", "32768") - size, err := strconv.Atoi(sizeStr) - if err != nil || size < 4096 || size > 1048576 { - return 32768 - } - return size -} +func (c *config) Domain() string { return c.domain } +func (c *config) SSHPort() string { return c.sshPort } +func (c *config) HTTPPort() string { return c.httpPort } +func (c *config) HTTPSPort() string { return c.httpsPort } +func (c *config) TLSEnabled() bool { return c.tlsEnabled } +func (c *config) TLSRedirect() bool { return c.tlsRedirect } +func (c *config) ACMEEmail() string { return c.acmeEmail } +func (c *config) CFAPIToken() string { return c.cfAPIToken } +func (c *config) ACMEStaging() bool { return c.acmeStaging } +func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart } +func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd } +func (c *config) BufferSize() int { return c.bufferSize } +func (c *config) PprofEnabled() bool { return c.pprofEnabled } +func (c *config) PprofPort() string { return c.pprofPort } +func (c *config) Mode() types.ServerMode { return c.mode } +func (c *config) GRPCAddress() string { return c.grpcAddress } +func (c *config) GRPCPort() string { return c.grpcPort } +func (c *config) NodeToken() string { return c.nodeToken } diff --git a/internal/config/loader.go b/internal/config/loader.go new file mode 100644 index 0000000..cde9fd0 --- /dev/null +++ b/internal/config/loader.go @@ -0,0 +1,170 @@ +package config + +import ( + "fmt" + "log" + "os" + "strconv" + "strings" + "tunnel_pls/types" + + "github.com/joho/godotenv" +) + +type config struct { + domain string + sshPort string + + httpPort string + httpsPort string + + tlsEnabled bool + tlsRedirect bool + + acmeEmail string + cfAPIToken string + acmeStaging bool + + allowedPortsStart uint16 + allowedPortsEnd uint16 + + bufferSize int + + pprofEnabled bool + pprofPort string + + mode types.ServerMode + grpcAddress string + grpcPort string + nodeToken string +} + +func parse() (*config, error) { + mode, err := parseMode() + if err != nil { + return nil, err + } + + domain := getenv("DOMAIN", "localhost") + sshPort := getenv("PORT", "2200") + + httpPort := getenv("HTTP_PORT", "8080") + httpsPort := getenv("HTTPS_PORT", "8443") + + tlsEnabled := getenvBool("TLS_ENABLED", false) + tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false) + + acmeEmail := getenv("ACME_EMAIL", "admin@"+domain) + acmeStaging := getenvBool("ACME_STAGING", false) + + cfToken := getenv("CF_API_TOKEN", "") + if tlsEnabled && cfToken == "" { + return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled") + } + + start, end, err := parseAllowedPorts() + if err != nil { + return nil, err + } + + bufferSize := parseBufferSize() + + pprofEnabled := getenvBool("PPROF_ENABLED", false) + pprofPort := getenv("PPROF_PORT", "6060") + + grpcHost := getenv("GRPC_ADDRESS", "localhost") + grpcPort := getenv("GRPC_PORT", "8080") + + nodeToken := getenv("NODE_TOKEN", "") + if mode == types.ServerModeNODE && nodeToken == "" { + return nil, fmt.Errorf("NODE_TOKEN is required in node mode") + } + + return &config{ + domain: domain, + sshPort: sshPort, + httpPort: httpPort, + httpsPort: httpsPort, + tlsEnabled: tlsEnabled, + tlsRedirect: tlsRedirect, + acmeEmail: acmeEmail, + cfAPIToken: cfToken, + acmeStaging: acmeStaging, + allowedPortsStart: start, + allowedPortsEnd: end, + bufferSize: bufferSize, + pprofEnabled: pprofEnabled, + pprofPort: pprofPort, + mode: mode, + grpcAddress: grpcHost, + grpcPort: grpcPort, + nodeToken: nodeToken, + }, nil +} + +func loadEnvFile() error { + if _, err := os.Stat(".env"); err == nil { + return godotenv.Load(".env") + } + return nil +} + +func parseMode() (types.ServerMode, error) { + switch strings.ToLower(getenv("MODE", "standalone")) { + case "standalone": + return types.ServerModeSTANDALONE, nil + case "node": + return types.ServerModeNODE, nil + default: + return 0, fmt.Errorf("invalid MODE value") + } +} + +func parseAllowedPorts() (uint16, uint16, error) { + raw := getenv("ALLOWED_PORTS", "") + if raw == "" { + return 0, 0, nil + } + + parts := strings.Split(raw, "-") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format") + } + + start, err := strconv.ParseUint(parts[0], 10, 16) + if err != nil { + return 0, 0, err + } + + end, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return 0, 0, err + } + + return uint16(start), uint16(end), nil +} + +func parseBufferSize() int { + raw := getenv("BUFFER_SIZE", "32768") + size, err := strconv.Atoi(raw) + if err != nil || size < 4096 || size > 1048576 { + log.Println("Invalid BUFFER_SIZE, falling back to 4096") + return 4096 + } + return size +} + +func getenv(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} + +func getenvBool(key string, def bool) bool { + val := os.Getenv(key) + if val == "" { + return def + } + return val == "true" +} diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index ffcc7d9..f2e0a1e 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -8,10 +8,9 @@ import ( "log" "time" "tunnel_pls/internal/config" + "tunnel_pls/internal/registry" "tunnel_pls/types" - "tunnel_pls/session" - proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -30,15 +29,16 @@ type Client interface { CheckServerHealth(ctx context.Context) error } type client struct { + config config.Config conn *grpc.ClientConn address string - sessionRegistry session.Registry + sessionRegistry registry.Registry eventService proto.EventServiceClient authorizeConnectionService proto.UserServiceClient closing bool } -func New(address string, sessionRegistry session.Registry) (Client, error) { +func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -67,6 +67,7 @@ func New(address string, sessionRegistry session.Registry) (Client, error) { authorizeConnectionService := proto.NewUserServiceClient(conn) return &client{ + config: config, conn: conn, address: address, sessionRegistry: sessionRegistry, @@ -193,7 +194,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, oldSlug := slugEvent.GetOld() newSlug := slugEvent.GetNew() - userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP}) + userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}) if err != nil { return c.sendNode(subscribe, &proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, @@ -203,7 +204,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, }, "slug change failure response") } - if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil { + if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil { return c.sendNode(subscribe, &proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, Payload: &proto.Node_SlugEventResponse{ @@ -228,7 +229,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node for _, ses := range sessions { detail := ses.Detail() details = append(details, &proto.Detail{ - Node: config.Getenv("DOMAIN", "localhost"), + Node: c.config.Domain(), ForwardingType: detail.ForwardingType, Slug: detail.Slug, UserId: detail.UserID, @@ -300,11 +301,11 @@ func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) { switch t { case proto.TunnelType_HTTP: - return types.HTTP, nil + return types.TunnelTypeHTTP, nil case proto.TunnelType_TCP: - return types.TCP, nil + return types.TunnelTypeTCP, nil default: - return types.UNKNOWN, fmt.Errorf("unknown tunnel type received") + return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received") } } diff --git a/internal/http/header/header.go b/internal/http/header/header.go new file mode 100644 index 0000000..605f9ec --- /dev/null +++ b/internal/http/header/header.go @@ -0,0 +1,30 @@ +package header + +type ResponseHeader interface { + Value(key string) string + Set(key string, value string) + Remove(key string) + Finalize() []byte +} + +type responseHeader struct { + startLine []byte + headers map[string]string +} + +type RequestHeader interface { + Value(key string) string + Set(key string, value string) + Remove(key string) + Finalize() []byte + Method() string + Path() string + Version() string +} +type requestHeader struct { + method string + path string + version string + startLine []byte + headers map[string]string +} diff --git a/internal/http/header/parser.go b/internal/http/header/parser.go new file mode 100644 index 0000000..861c49e --- /dev/null +++ b/internal/http/header/parser.go @@ -0,0 +1,148 @@ +package header + +import ( + "bufio" + "bytes" + "fmt" +) + +func setRemainingHeaders(remaining []byte, header interface { + Set(key string, value string) +}) { + for len(remaining) > 0 { + lineEnd := bytes.Index(remaining, []byte("\r\n")) + if lineEnd == -1 { + lineEnd = len(remaining) + } + + line := remaining[:lineEnd] + + if len(line) == 0 { + break + } + + colonIdx := bytes.IndexByte(line, ':') + if colonIdx != -1 { + key := bytes.TrimSpace(line[:colonIdx]) + value := bytes.TrimSpace(line[colonIdx+1:]) + header.Set(string(key), string(value)) + } + + if lineEnd == len(remaining) { + break + } + + remaining = remaining[lineEnd+2:] + } +} + +func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) { + header := &requestHeader{ + headers: make(map[string]string, 16), + } + + lineEnd := bytes.Index(headerData, []byte("\r\n")) + if lineEnd == -1 { + return nil, fmt.Errorf("invalid request: no CRLF found in start line") + } + + startLine := headerData[:lineEnd] + header.startLine = startLine + var err error + header.method, header.path, header.version, err = parseStartLine(startLine) + if err != nil { + return nil, err + } + + remaining := headerData[lineEnd+2:] + + setRemainingHeaders(remaining, header) + + return header, nil +} + +func parseStartLine(startLine []byte) (method, path, version string, err error) { + firstSpace := bytes.IndexByte(startLine, ' ') + if firstSpace == -1 { + return "", "", "", fmt.Errorf("invalid start line: missing method") + } + + secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ') + if secondSpace == -1 { + return "", "", "", fmt.Errorf("invalid start line: missing version") + } + secondSpace += firstSpace + 1 + + method = string(startLine[:firstSpace]) + path = string(startLine[firstSpace+1 : secondSpace]) + version = string(startLine[secondSpace+1:]) + + return method, path, version, nil +} + +func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) { + header := &requestHeader{ + headers: make(map[string]string, 16), + } + + startLineBytes, err := br.ReadSlice('\n') + if err != nil { + return nil, err + } + + startLineBytes = bytes.TrimRight(startLineBytes, "\r\n") + header.startLine = make([]byte, len(startLineBytes)) + copy(header.startLine, startLineBytes) + + header.method, header.path, header.version, err = parseStartLine(header.startLine) + if err != nil { + return nil, err + } + + for { + lineBytes, err := br.ReadSlice('\n') + if err != nil { + return nil, err + } + + lineBytes = bytes.TrimRight(lineBytes, "\r\n") + + if len(lineBytes) == 0 { + break + } + + colonIdx := bytes.IndexByte(lineBytes, ':') + if colonIdx == -1 { + continue + } + + key := bytes.TrimSpace(lineBytes[:colonIdx]) + value := bytes.TrimSpace(lineBytes[colonIdx+1:]) + + header.headers[string(key)] = string(value) + } + + return header, nil +} + +func finalize(startLine []byte, headers map[string]string) []byte { + size := len(startLine) + 2 + for key, val := range headers { + size += len(key) + 2 + len(val) + 2 + } + size += 2 + + buf := make([]byte, 0, size) + buf = append(buf, startLine...) + buf = append(buf, '\r', '\n') + + for key, val := range headers { + buf = append(buf, key...) + buf = append(buf, ':', ' ') + buf = append(buf, val...) + buf = append(buf, '\r', '\n') + } + + buf = append(buf, '\r', '\n') + return buf +} diff --git a/internal/http/header/request.go b/internal/http/header/request.go new file mode 100644 index 0000000..1fbe57a --- /dev/null +++ b/internal/http/header/request.go @@ -0,0 +1,49 @@ +package header + +import ( + "bufio" + "fmt" +) + +func NewRequest(r interface{}) (RequestHeader, error) { + switch v := r.(type) { + case []byte: + return parseHeadersFromBytes(v) + case *bufio.Reader: + return parseHeadersFromReader(v) + default: + return nil, fmt.Errorf("unsupported type: %T", r) + } +} + +func (req *requestHeader) Value(key string) string { + val, ok := req.headers[key] + if !ok { + return "" + } + return val +} + +func (req *requestHeader) Set(key string, value string) { + req.headers[key] = value +} + +func (req *requestHeader) Remove(key string) { + delete(req.headers, key) +} + +func (req *requestHeader) Method() string { + return req.method +} + +func (req *requestHeader) Path() string { + return req.path +} + +func (req *requestHeader) Version() string { + return req.version +} + +func (req *requestHeader) Finalize() []byte { + return finalize(req.startLine, req.headers) +} diff --git a/internal/http/header/response.go b/internal/http/header/response.go new file mode 100644 index 0000000..b6305d4 --- /dev/null +++ b/internal/http/header/response.go @@ -0,0 +1,40 @@ +package header + +import ( + "bytes" + "fmt" +) + +func NewResponse(headerData []byte) (ResponseHeader, error) { + header := &responseHeader{ + startLine: nil, + headers: make(map[string]string, 16), + } + + lineEnd := bytes.Index(headerData, []byte("\r\n")) + if lineEnd == -1 { + return nil, fmt.Errorf("invalid response: no CRLF found in start line") + } + + header.startLine = headerData[:lineEnd] + remaining := headerData[lineEnd+2:] + setRemainingHeaders(remaining, header) + + return header, nil +} + +func (resp *responseHeader) Value(key string) string { + return resp.headers[key] +} + +func (resp *responseHeader) Set(key string, value string) { + resp.headers[key] = value +} + +func (resp *responseHeader) Remove(key string) { + delete(resp.headers, key) +} + +func (resp *responseHeader) Finalize() []byte { + return finalize(resp.startLine, resp.headers) +} diff --git a/internal/http/stream/parser.go b/internal/http/stream/parser.go new file mode 100644 index 0000000..b1d8277 --- /dev/null +++ b/internal/http/stream/parser.go @@ -0,0 +1,29 @@ +package stream + +import "bytes" + +func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) { + headerByte := data[:delimiterIdx+len(DELIMITER)] + body := data[delimiterIdx+len(DELIMITER):] + return headerByte, body +} + +func isHTTPHeader(buf []byte) bool { + lines := bytes.Split(buf, []byte("\r\n")) + + startLine := string(lines[0]) + if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { + return false + } + + for _, line := range lines[1:] { + if len(line) == 0 { + break + } + colonIdx := bytes.IndexByte(line, ':') + if colonIdx <= 0 { + return false + } + } + return true +} diff --git a/internal/http/stream/reader.go b/internal/http/stream/reader.go new file mode 100644 index 0000000..f7c99d4 --- /dev/null +++ b/internal/http/stream/reader.go @@ -0,0 +1,50 @@ +package stream + +import ( + "bytes" + "tunnel_pls/internal/http/header" +) + +func (hs *http) Read(p []byte) (int, error) { + tmp := make([]byte, len(p)) + read, err := hs.reader.Read(tmp) + if read == 0 && err != nil { + return 0, err + } + + tmp = tmp[:read] + + headerEndIdx := bytes.Index(tmp, DELIMITER) + if headerEndIdx == -1 { + return handleNoDelimiter(p, tmp, err) + } + + headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx) + + if !isHTTPHeader(headerByte) { + copy(p, tmp) + return read, nil + } + + return hs.processHTTPRequest(p, headerByte, bodyByte) +} + +func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) { + reqhf, err := header.NewRequest(headerByte) + if err != nil { + return 0, err + } + + if err = hs.ApplyRequestMiddlewares(reqhf); err != nil { + return 0, err + } + + hs.reqHeader = reqhf + combined := append(reqhf.Finalize(), bodyByte...) + return copy(p, combined), nil +} + +func handleNoDelimiter(p, tmp []byte, err error) (int, error) { + copy(p, tmp) + return len(tmp), err +} diff --git a/internal/http/stream/stream.go b/internal/http/stream/stream.go new file mode 100644 index 0000000..97d2752 --- /dev/null +++ b/internal/http/stream/stream.go @@ -0,0 +1,103 @@ +package stream + +import ( + "io" + "log" + "net" + "regexp" + "tunnel_pls/internal/http/header" + "tunnel_pls/internal/middleware" +) + +var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} +var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) +var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) + +type HTTP interface { + io.ReadWriteCloser + CloseWrite() error + RemoteAddr() net.Addr + UseResponseMiddleware(mw middleware.ResponseMiddleware) + UseRequestMiddleware(mw middleware.RequestMiddleware) + SetRequestHeader(header header.RequestHeader) + RequestMiddlewares() []middleware.RequestMiddleware + ResponseMiddlewares() []middleware.ResponseMiddleware + ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error + ApplyRequestMiddlewares(reqhf header.RequestHeader) error +} + +type http struct { + remoteAddr net.Addr + writer io.Writer + reader io.Reader + headerBuf []byte + buf []byte + respHeader header.ResponseHeader + reqHeader header.RequestHeader + respMW []middleware.ResponseMiddleware + reqMW []middleware.RequestMiddleware +} + +func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP { + return &http{ + remoteAddr: remoteAddr, + writer: writer, + reader: reader, + buf: make([]byte, 0, 4096), + } +} + +func (hs *http) RemoteAddr() net.Addr { + return hs.remoteAddr +} + +func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) { + hs.respMW = append(hs.respMW, mw) +} + +func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) { + hs.reqMW = append(hs.reqMW, mw) +} + +func (hs *http) SetRequestHeader(header header.RequestHeader) { + hs.reqHeader = header +} + +func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware { + return hs.reqMW +} + +func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware { + return hs.respMW +} + +func (hs *http) Close() error { + return hs.writer.(io.Closer).Close() +} + +func (hs *http) CloseWrite() error { + if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok { + return closer.CloseWrite() + } + return hs.Close() +} + +func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error { + for _, m := range hs.RequestMiddlewares() { + if err := m.HandleRequest(reqhf); err != nil { + log.Printf("Error when applying request middleware: %v", err) + return err + } + } + return nil +} + +func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error { + for _, m := range hs.ResponseMiddlewares() { + if err := m.HandleResponse(resphf, bodyByte); err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return err + } + } + return nil +} diff --git a/internal/http/stream/writer.go b/internal/http/stream/writer.go new file mode 100644 index 0000000..05e438b --- /dev/null +++ b/internal/http/stream/writer.go @@ -0,0 +1,88 @@ +package stream + +import ( + "bytes" + "tunnel_pls/internal/http/header" +) + +func (hs *http) Write(p []byte) (int, error) { + if hs.shouldBypassBuffering(p) { + hs.respHeader = nil + } + + if hs.respHeader != nil { + return hs.writer.Write(p) + } + + hs.buf = append(hs.buf, p...) + + headerEndIdx := bytes.Index(hs.buf, DELIMITER) + if headerEndIdx == -1 { + return len(p), nil + } + + return hs.processBufferedResponse(p, headerEndIdx) +} + +func (hs *http) shouldBypassBuffering(p []byte) bool { + return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" +} + +func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) { + headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx) + + if !isHTTPHeader(headerByte) { + return hs.writeRawBuffer() + } + + if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil { + return 0, err + } + + hs.buf = nil + return len(p), nil +} + +func (hs *http) writeRawBuffer() (int, error) { + _, err := hs.writer.Write(hs.buf) + length := len(hs.buf) + hs.buf = nil + if err != nil { + return 0, err + } + return length, nil +} + +func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error { + resphf, err := header.NewResponse(headerByte) + if err != nil { + return err + } + + if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil { + return err + } + + hs.respHeader = resphf + finalHeader := resphf.Finalize() + + if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil { + return err + } + + return nil +} + +func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error { + if _, err := hs.writer.Write(header); err != nil { + return err + } + + if len(bodyByte) > 0 { + if _, err := hs.writer.Write(bodyByte); err != nil { + return err + } + } + + return nil +} diff --git a/internal/middleware/forwardedfor.go b/internal/middleware/forwardedfor.go new file mode 100644 index 0000000..6a744a6 --- /dev/null +++ b/internal/middleware/forwardedfor.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "net" + "tunnel_pls/internal/http/header" +) + +type ForwardedFor struct { + addr net.Addr +} + +func NewForwardedFor(addr net.Addr) *ForwardedFor { + return &ForwardedFor{addr: addr} +} + +func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error { + host, _, err := net.SplitHostPort(ff.addr.String()) + if err != nil { + return err + } + header.Set("X-Forwarded-For", host) + return nil +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go new file mode 100644 index 0000000..4e3b163 --- /dev/null +++ b/internal/middleware/middleware.go @@ -0,0 +1,13 @@ +package middleware + +import ( + "tunnel_pls/internal/http/header" +) + +type RequestMiddleware interface { + HandleRequest(header header.RequestHeader) error +} + +type ResponseMiddleware interface { + HandleResponse(header header.ResponseHeader, body []byte) error +} diff --git a/internal/middleware/tunnelfingerprint.go b/internal/middleware/tunnelfingerprint.go new file mode 100644 index 0000000..68171c2 --- /dev/null +++ b/internal/middleware/tunnelfingerprint.go @@ -0,0 +1,16 @@ +package middleware + +import ( + "tunnel_pls/internal/http/header" +) + +type TunnelFingerprint struct{} + +func NewTunnelFingerprint() *TunnelFingerprint { + return &TunnelFingerprint{} +} + +func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error { + header.Set("Server", "Tunnel Please") + return nil +} diff --git a/internal/port/port.go b/internal/port/port.go index 01ecf96..6c60fbb 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -6,37 +6,37 @@ import ( "sync" ) -type Registry interface { - AddPortRange(startPort, endPort uint16) error - GetUnassignedPort() (uint16, bool) - SetPortStatus(port uint16, assigned bool) error - ClaimPort(port uint16) (claimed bool) +type Port interface { + AddRange(startPort, endPort uint16) error + Unassigned() (uint16, bool) + SetStatus(port uint16, assigned bool) error + Claim(port uint16) (claimed bool) } -type registry struct { +type port struct { mu sync.RWMutex ports map[uint16]bool sortedPorts []uint16 } -func New() Registry { - return ®istry{ +func New() Port { + return &port{ ports: make(map[uint16]bool), sortedPorts: []uint16{}, } } -func (pm *registry) AddPortRange(startPort, endPort uint16) error { +func (pm *port) AddRange(startPort, endPort uint16) error { pm.mu.Lock() defer pm.mu.Unlock() if startPort > endPort { return fmt.Errorf("start port cannot be greater than end port") } - for port := startPort; port <= endPort; port++ { - if _, exists := pm.ports[port]; !exists { - pm.ports[port] = false - pm.sortedPorts = append(pm.sortedPorts, port) + for index := startPort; index <= endPort; index++ { + if _, exists := pm.ports[index]; !exists { + pm.ports[index] = false + pm.sortedPorts = append(pm.sortedPorts, index) } } sort.Slice(pm.sortedPorts, func(i, j int) bool { @@ -45,19 +45,19 @@ func (pm *registry) AddPortRange(startPort, endPort uint16) error { return nil } -func (pm *registry) GetUnassignedPort() (uint16, bool) { +func (pm *port) Unassigned() (uint16, bool) { pm.mu.Lock() defer pm.mu.Unlock() - for _, port := range pm.sortedPorts { - if !pm.ports[port] { - return port, true + for _, index := range pm.sortedPorts { + if !pm.ports[index] { + return index, true } } return 0, false } -func (pm *registry) SetPortStatus(port uint16, assigned bool) error { +func (pm *port) SetStatus(port uint16, assigned bool) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -65,7 +65,7 @@ func (pm *registry) SetPortStatus(port uint16, assigned bool) error { return nil } -func (pm *registry) ClaimPort(port uint16) (claimed bool) { +func (pm *port) Claim(port uint16) (claimed bool) { pm.mu.Lock() defer pm.mu.Unlock() diff --git a/session/registry.go b/internal/registry/registry.go similarity index 83% rename from session/registry.go rename to internal/registry/registry.go index 6698cf1..89cac48 100644 --- a/session/registry.go +++ b/internal/registry/registry.go @@ -1,13 +1,25 @@ -package session +package registry import ( "fmt" "sync" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" "tunnel_pls/types" ) type Key = types.SessionKey +type Session interface { + Lifecycle() lifecycle.Lifecycle + Interaction() interaction.Interaction + Forwarder() forwarder.Forwarder + Slug() slug.Slug + Detail() *types.Detail +} + type Registry interface { Get(key Key) (session Session, err error) GetWithUser(user string, key Key) (session Session, err error) @@ -22,6 +34,15 @@ type registry struct { slugIndex map[Key]string } +var ( + ErrSessionNotFound = fmt.Errorf("session not found") + ErrSlugInUse = fmt.Errorf("slug already in use") + ErrInvalidSlug = fmt.Errorf("invalid slug") + ErrForbiddenSlug = fmt.Errorf("forbidden slug") + ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type") + ErrSlugUnchanged = fmt.Errorf("slug is unchanged") +) + func NewRegistry() Registry { return ®istry{ byUser: make(map[string]map[Key]Session), @@ -35,12 +56,12 @@ func (r *registry) Get(key Key) (session Session, err error) { userID, ok := r.slugIndex[key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, ErrSessionNotFound } client, ok := r.byUser[userID][key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, ErrSessionNotFound } return client, nil } @@ -51,37 +72,37 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error client, ok := r.byUser[user][key] if !ok { - return nil, fmt.Errorf("session not found") + return nil, ErrSessionNotFound } return client, nil } func (r *registry) Update(user string, oldKey, newKey Key) error { if oldKey.Type != newKey.Type { - return fmt.Errorf("tunnel type cannot change") + return ErrSlugUnchanged } - if newKey.Type != types.HTTP { - return fmt.Errorf("non http tunnel cannot change slug") + if newKey.Type != types.TunnelTypeHTTP { + return ErrSlugChangeNotAllowed } if isForbiddenSlug(newKey.Id) { - return fmt.Errorf("this subdomain is reserved. Please choose a different one") + return ErrForbiddenSlug } if !isValidSlug(newKey.Id) { - return fmt.Errorf("invalid subdomain. Follow the rules") + return ErrInvalidSlug } r.mu.Lock() defer r.mu.Unlock() if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { - return fmt.Errorf("someone already uses this subdomain") + return ErrSlugInUse } client, ok := r.byUser[user][oldKey] if !ok { - return fmt.Errorf("session not found") + return ErrSessionNotFound } delete(r.byUser[user], oldKey) @@ -97,7 +118,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { return nil } -func (r *registry) Register(key Key, session Session) (success bool) { +func (r *registry) Register(key Key, userSession Session) (success bool) { r.mu.Lock() defer r.mu.Unlock() @@ -105,12 +126,12 @@ func (r *registry) Register(key Key, session Session) (success bool) { return false } - userID := session.Lifecycle().User() + userID := userSession.Lifecycle().User() if r.byUser[userID] == nil { r.byUser[userID] = make(map[Key]Session) } - r.byUser[userID][key] = session + r.byUser[userID][key] = userSession r.slugIndex[key] = userID return true } diff --git a/internal/transport/http.go b/internal/transport/http.go new file mode 100644 index 0000000..dd091c3 --- /dev/null +++ b/internal/transport/http.go @@ -0,0 +1,40 @@ +package transport + +import ( + "errors" + "log" + "net" + "tunnel_pls/internal/registry" +) + +type httpServer struct { + handler *httpHandler + port string +} + +func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { + return &httpServer{ + handler: newHTTPHandler(domain, sessionRegistry, redirectTLS), + port: port, + } +} + +func (ht *httpServer) Listen() (net.Listener, error) { + return net.Listen("tcp", ":"+ht.port) +} + +func (ht *httpServer) Serve(listener net.Listener) error { + log.Printf("HTTP server is starting on port %s", ht.port) + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return err + } + log.Printf("Error accepting connection: %v", err) + continue + } + + go ht.handler.handler(conn, false) + } +} diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go new file mode 100644 index 0000000..8bab4a0 --- /dev/null +++ b/internal/transport/httphandler.go @@ -0,0 +1,231 @@ +package transport + +import ( + "bufio" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "strings" + "time" + "tunnel_pls/internal/http/header" + "tunnel_pls/internal/http/stream" + "tunnel_pls/internal/middleware" + "tunnel_pls/internal/registry" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" +) + +type httpHandler struct { + domain string + sessionRegistry registry.Registry + redirectTLS bool +} + +func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { + return &httpHandler{ + domain: domain, + sessionRegistry: sessionRegistry, + redirectTLS: redirectTLS, + } +} + +func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error { + _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + + fmt.Sprintf("Location: %s", location) + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "\r\n")) + if err != nil { + return err + } + return nil +} + +func (hh *httpHandler) badRequest(conn net.Conn) error { + if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { + return err + } + return nil +} + +func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { + defer hh.closeConnection(conn) + + dstReader := bufio.NewReader(conn) + reqhf, err := header.NewRequest(dstReader) + if err != nil { + log.Printf("Error creating request header: %v", err) + return + } + + slug, err := hh.extractSlug(reqhf) + if err != nil { + _ = hh.badRequest(conn) + return + } + + if hh.shouldRedirectToTLS(isTLS) { + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain)) + return + } + + if hh.handlePingRequest(slug, conn) { + return + } + + sshSession, err := hh.getSession(slug) + if err != nil { + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) + return + } + + hw := stream.New(conn, dstReader, conn.RemoteAddr()) + defer func(hw stream.HTTP) { + err = hw.Close() + if err != nil { + log.Printf("Error closing HTTP stream: %v", err) + } + }(hw) + hh.forwardRequest(hw, reqhf, sshSession) +} + +func (hh *httpHandler) closeConnection(conn net.Conn) { + err := conn.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + log.Printf("Error closing connection: %v", err) + } +} + +func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) { + host := strings.Split(reqhf.Value("Host"), ".") + if len(host) < 1 { + return "", errors.New("invalid host") + } + return host[0], nil +} + +func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool { + return !isTLS && hh.redirectTLS +} + +func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { + if slug != "ping" { + return false + } + + _, err := conn.Write([]byte( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + + "Access-Control-Allow-Headers: *\r\n" + + "\r\n", + )) + if err != nil { + log.Println("Failed to write 200 OK:", err) + } + return true +} + +func (hh *httpHandler) getSession(slug string) (registry.Session, error) { + sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ + Id: slug, + Type: types.TunnelTypeHTTP, + }) + if err != nil { + return nil, err + } + return sshSession, nil +} + +func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { + channel, err := hh.openForwardedChannel(hw, sshSession) + if err != nil { + log.Printf("Failed to establish channel: %v", err) + sshSession.Forwarder().WriteBadGatewayResponse(hw) + return + } + + defer func() { + err = channel.Close() + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing forwarded channel: %v", err) + } + }() + + hh.setupMiddlewares(hw) + + if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil { + log.Printf("Failed to forward initial request: %v", err) + return + } + sshSession.Forwarder().HandleConnection(hw, channel) +} + +func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) { + payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) + + type channelResult struct { + channel ssh.Channel + reqs <-chan *ssh.Request + err error + } + + resultChan := make(chan channelResult, 1) + + go func() { + channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) + select { + case resultChan <- channelResult{channel, reqs, err}: + default: + hh.cleanupUnusedChannel(channel, reqs) + } + }() + + select { + case result := <-resultChan: + if result.err != nil { + return nil, result.err + } + go ssh.DiscardRequests(result.reqs) + return result.channel, nil + case <-time.After(5 * time.Second): + return nil, errors.New("timeout opening forwarded-tcpip channel") + } +} + +func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) { + if channel != nil { + if err := channel.Close(); err != nil { + log.Printf("Failed to close unused channel: %v", err) + } + go ssh.DiscardRequests(reqs) + } +} + +func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) { + fingerprintMiddleware := middleware.NewTunnelFingerprint() + forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr()) + + hw.UseResponseMiddleware(fingerprintMiddleware) + hw.UseRequestMiddleware(forwardedForMiddleware) +} + +func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error { + hw.SetRequestHeader(initialRequest) + + if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { + return fmt.Errorf("error applying request middlewares: %w", err) + } + + if _, err := channel.Write(initialRequest.Finalize()); err != nil { + return fmt.Errorf("error writing to channel: %w", err) + } + + return nil +} diff --git a/internal/transport/https.go b/internal/transport/https.go new file mode 100644 index 0000000..88ffe27 --- /dev/null +++ b/internal/transport/https.go @@ -0,0 +1,45 @@ +package transport + +import ( + "crypto/tls" + "errors" + "log" + "net" + "tunnel_pls/internal/registry" +) + +type https struct { + tlsConfig *tls.Config + httpHandler *httpHandler + domain string + port string +} + +func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport { + return &https{ + tlsConfig: tlsConfig, + httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS), + domain: domain, + port: port, + } +} + +func (ht *https) Listen() (net.Listener, error) { + return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig) +} + +func (ht *https) Serve(listener net.Listener) error { + log.Printf("HTTPS server is starting on port %s", ht.port) + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return err + } + log.Printf("Error accepting connection: %v", err) + continue + } + + go ht.httpHandler.handler(conn, true) + } +} diff --git a/internal/transport/tcp.go b/internal/transport/tcp.go new file mode 100644 index 0000000..99670d2 --- /dev/null +++ b/internal/transport/tcp.go @@ -0,0 +1,66 @@ +package transport + +import ( + "errors" + "fmt" + "io" + "log" + "net" + + "golang.org/x/crypto/ssh" +) + +type tcp struct { + port uint16 + forwarder forwarder +} + +type forwarder interface { + CreateForwardedTCPIPPayload(origin net.Addr) []byte + OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) + HandleConnection(dst io.ReadWriter, src ssh.Channel) +} + +func NewTCPServer(port uint16, forwarder forwarder) Transport { + return &tcp{ + port: port, + forwarder: forwarder, + } +} + +func (tt *tcp) Listen() (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port)) +} + +func (tt *tcp) Serve(listener net.Listener) error { + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + log.Printf("Error accepting connection: %v", err) + continue + } + go tt.handleTcp(conn) + } +} + +func (tt *tcp) handleTcp(conn net.Conn) { + defer func() { + err := conn.Close() + if err != nil { + log.Printf("Failed to close connection: %v", err) + } + }() + payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr()) + channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + + return + } + + go ssh.DiscardRequests(reqs) + tt.forwarder.HandleConnection(conn, channel) +} diff --git a/server/tls.go b/internal/transport/tls.go similarity index 84% rename from server/tls.go rename to internal/transport/tls.go index fc67733..877afb4 100644 --- a/server/tls.go +++ b/internal/transport/tls.go @@ -1,4 +1,4 @@ -package server +package transport import ( "context" @@ -26,7 +26,8 @@ type TLSManager interface { } type tlsManager struct { - domain string + config config.Config + certPath string keyPath string storagePath string @@ -42,7 +43,7 @@ type tlsManager struct { var globalTLSManager TLSManager var tlsManagerOnce sync.Once -func NewTLSConfig(domain string) (*tls.Config, error) { +func NewTLSConfig(config config.Config) (*tls.Config, error) { var initErr error tlsManagerOnce.Do(func() { @@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) { storagePath := "certs/tls/certmagic" tm := &tlsManager{ - domain: domain, + config: config, certPath: certPath, keyPath: keyPath, storagePath: storagePath, @@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) { tm.useCertMagic = false tm.startCertWatcher() } else { - if !isACMEConfigComplete() { - log.Printf("User certificates missing or invalid, and ACME configuration is incomplete") - log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable") - initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)") - return - } - - log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain) + log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain()) if err := tm.initCertMagic(); err != nil { initErr = fmt.Errorf("failed to initialize CertMagic: %w", err) return @@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) { return globalTLSManager.getTLSConfig(), nil } -func isACMEConfigComplete() bool { - cfAPIToken := config.Getenv("CF_API_TOKEN", "") - return cfAPIToken != "" -} - func (tm *tlsManager) userCertsExistAndValid() bool { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { log.Printf("Certificate file not found: %s", tm.certPath) @@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool { return false } - return ValidateCertDomains(tm.certPath, tm.domain) + return ValidateCertDomains(tm.certPath, tm.config.Domain()) } func ValidateCertDomains(certPath, domain string) bool { @@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() { if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) { log.Printf("Certificate files changed, reloading...") - if !ValidateCertDomains(tm.certPath, tm.domain) { + if !ValidateCertDomains(tm.certPath, tm.config.Domain()) { log.Printf("New certificates don't cover required domains") - if !isACMEConfigComplete() { - log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)") - continue - } - - log.Printf("Switching to CertMagic for automatic certificate management") if err := tm.initCertMagic(); err != nil { log.Printf("Failed to initialize CertMagic: %v", err) continue @@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error { return fmt.Errorf("failed to create cert storage directory: %w", err) } - acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain) - cfAPIToken := config.Getenv("CF_API_TOKEN", "") - acmeStaging := config.Getenv("ACME_STAGING", "false") == "true" - - if cfAPIToken == "" { + if tm.config.CFAPIToken() == "" { return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") } cfProvider := &cloudflare.Provider{ - APIToken: cfAPIToken, + APIToken: tm.config.CFAPIToken(), } storage := &certmagic.FileStorage{Path: tm.storagePath} @@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error { }) acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ - Email: acmeEmail, + Email: tm.config.ACMEEmail(), Agreed: true, DNS01Solver: &certmagic.DNS01Solver{ DNSManager: certmagic.DNSManager{ @@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error { }, }) - if acmeStaging { + if tm.config.ACMEStaging() { acmeIssuer.CA = certmagic.LetsEncryptStagingCA log.Printf("Using Let's Encrypt staging server") } else { @@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error { magic.Issuers = []certmagic.Issuer{acmeIssuer} tm.magic = magic - domains := []string{tm.domain, "*." + tm.domain} + domains := []string{tm.config.Domain(), "*." + tm.config.Domain()} log.Printf("Requesting certificates for: %v", domains) ctx := context.Background() diff --git a/internal/transport/transport.go b/internal/transport/transport.go new file mode 100644 index 0000000..ca27061 --- /dev/null +++ b/internal/transport/transport.go @@ -0,0 +1,10 @@ +package transport + +import ( + "net" +) + +type Transport interface { + Listen() (net.Listener, error) + Serve(listener net.Listener) error +} diff --git a/version/version.go b/internal/version/version.go similarity index 100% rename from version/version.go rename to internal/version/version.go diff --git a/main.go b/main.go index 2303718..f897b46 100644 --- a/main.go +++ b/main.go @@ -4,21 +4,22 @@ import ( "context" "fmt" "log" + "net" "net/http" _ "net/http/pprof" "os" "os/signal" - "strconv" - "strings" "syscall" "time" "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/key" "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" + "tunnel_pls/internal/transport" + "tunnel_pls/internal/version" "tunnel_pls/server" - "tunnel_pls/session" - "tunnel_pls/version" + "tunnel_pls/types" "golang.org/x/crypto/ssh" ) @@ -34,27 +35,12 @@ func main() { log.Printf("Starting %s", version.GetVersion()) - err := config.Load() + conf, err := config.MustLoad() if err != nil { log.Fatalf("Failed to load configuration: %s", err) return } - mode := strings.ToLower(config.Getenv("MODE", "standalone")) - isNodeMode := mode == "node" - - pprofEnabled := config.Getenv("PPROF_ENABLED", "false") - if pprofEnabled == "true" { - pprofPort := config.Getenv("PPROF_PORT", "6060") - go func() { - pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) - log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) - if err = http.ListenAndServe(pprofAddr, nil); err != nil { - log.Printf("pprof server error: %v", err) - } - }() - } - sshConfig := &ssh.ServerConfig{ NoClientAuth: true, ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), @@ -76,7 +62,7 @@ func main() { } sshConfig.AddHostKey(private) - sessionRegistry := session.NewRegistry() + sessionRegistry := registry.NewRegistry() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -86,16 +72,11 @@ func main() { signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) var grpcClient client.Client - if isNodeMode { - grpcHost := config.Getenv("GRPC_ADDRESS", "localhost") - grpcPort := config.Getenv("GRPC_PORT", "8080") - grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort) - nodeToken := config.Getenv("NODE_TOKEN", "") - if nodeToken == "" { - log.Fatalf("NODE_TOKEN is required in node mode") - } - grpcClient, err = client.New(grpcAddr, sessionRegistry) + if conf.Mode() == types.ServerModeNODE { + grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort()) + + grpcClient, err = client.New(conf, grpcAddr, sessionRegistry) if err != nil { log.Fatalf("failed to create grpc client: %v", err) } @@ -108,47 +89,73 @@ func main() { healthCancel() go func() { - identity := config.Getenv("DOMAIN", "localhost") - if err = grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil { + if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { errChan <- fmt.Errorf("failed to subscribe to events: %w", err) } }() } - portManager := port.New() - rawRange := config.Getenv("ALLOWED_PORTS", "") - if rawRange != "" { - splitRange := strings.Split(rawRange, "-") - if len(splitRange) == 2 { - var start, end uint64 - start, err = strconv.ParseUint(splitRange[0], 10, 16) - if err != nil { - log.Fatalf("Failed to parse start port: %s", err) - } - - end, err = strconv.ParseUint(splitRange[1], 10, 16) - if err != nil { - log.Fatalf("Failed to parse end port: %s", err) - } - - if err = portManager.AddPortRange(uint16(start), uint16(end)); err != nil { - log.Fatalf("Failed to add port range: %s", err) - } - log.Printf("PortRegistry range configured: %d-%d", start, end) - } else { - log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange) + go func() { + var httpListener net.Listener + httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect()) + httpListener, err = httpserver.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to start http server: %w", err) + return } + err = httpserver.Serve(httpListener) + if err != nil { + errChan <- fmt.Errorf("error when serving http server: %w", err) + return + } + }() + + if conf.TLSEnabled() { + go func() { + var httpsListener net.Listener + tlsConfig, _ := transport.NewTLSConfig(conf) + httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig) + httpsListener, err = httpsServer.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to start http server: %w", err) + return + } + err = httpsServer.Serve(httpsListener) + if err != nil { + errChan <- fmt.Errorf("error when serving http server: %w", err) + return + } + }() } + + portManager := port.New() + err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()) + if err != nil { + log.Fatalf("Failed to initialize port manager: %s", err) + return + } + var app server.Server go func() { - app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager) + app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort()) if err != nil { errChan <- fmt.Errorf("failed to start server: %s", err) return } app.Start() + }() + if conf.PprofEnabled() { + go func() { + pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort()) + log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) + if err = http.ListenAndServe(pprofAddr, nil); err != nil { + log.Printf("pprof server error: %v", err) + } + }() + } + select { case err = <-errChan: log.Printf("error happen : %s", err) diff --git a/server/header.go b/server/header.go deleted file mode 100644 index 584394b..0000000 --- a/server/header.go +++ /dev/null @@ -1,276 +0,0 @@ -package server - -import ( - "bufio" - "bytes" - "fmt" -) - -type HeaderManager interface { - Get(key string) []byte - Set(key string, value []byte) - Remove(key string) - Finalize() []byte -} - -type ResponseHeaderManager interface { - Get(key string) string - Set(key string, value string) - Remove(key string) - Finalize() []byte -} - -type RequestHeaderManager interface { - Get(key string) string - Set(key string, value string) - Remove(key string) - Finalize() []byte - GetMethod() string - GetPath() string - GetVersion() string -} - -type responseHeaderFactory struct { - startLine []byte - headers map[string]string -} - -type requestHeaderFactory struct { - method string - path string - version string - startLine []byte - headers map[string]string -} - -func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) { - switch v := r.(type) { - case []byte: - return parseHeadersFromBytes(v) - case *bufio.Reader: - return parseHeadersFromReader(v) - default: - return nil, fmt.Errorf("unsupported type: %T", r) - } -} - -func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) { - header := &requestHeaderFactory{ - headers: make(map[string]string, 16), - } - - lineEnd := bytes.IndexByte(headerData, '\n') - if lineEnd == -1 { - return nil, fmt.Errorf("invalid request: no newline found") - } - - startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n") - header.startLine = make([]byte, len(startLine)) - copy(header.startLine, startLine) - - parts := bytes.Split(startLine, []byte{' '}) - if len(parts) < 3 { - return nil, fmt.Errorf("invalid request line") - } - - header.method = string(parts[0]) - header.path = string(parts[1]) - header.version = string(parts[2]) - - remaining := headerData[lineEnd+1:] - - for len(remaining) > 0 { - lineEnd = bytes.IndexByte(remaining, '\n') - if lineEnd == -1 { - lineEnd = len(remaining) - } - - line := bytes.TrimRight(remaining[:lineEnd], "\r\n") - - if len(line) == 0 { - break - } - - colonIdx := bytes.IndexByte(line, ':') - if colonIdx != -1 { - key := bytes.TrimSpace(line[:colonIdx]) - value := bytes.TrimSpace(line[colonIdx+1:]) - header.headers[string(key)] = string(value) - } - - if lineEnd == len(remaining) { - break - } - remaining = remaining[lineEnd+1:] - } - - return header, nil -} - -func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) { - header := &requestHeaderFactory{ - headers: make(map[string]string, 16), - } - - startLineBytes, err := br.ReadSlice('\n') - if err != nil { - if err == bufio.ErrBufferFull { - var startLine string - startLine, err = br.ReadString('\n') - if err != nil { - return nil, err - } - startLineBytes = []byte(startLine) - } else { - return nil, err - } - } - - startLineBytes = bytes.TrimRight(startLineBytes, "\r\n") - header.startLine = make([]byte, len(startLineBytes)) - copy(header.startLine, startLineBytes) - - parts := bytes.Split(startLineBytes, []byte{' '}) - if len(parts) < 3 { - return nil, fmt.Errorf("invalid request line") - } - - header.method = string(parts[0]) - header.path = string(parts[1]) - header.version = string(parts[2]) - - for { - lineBytes, err := br.ReadSlice('\n') - if err != nil { - if err == bufio.ErrBufferFull { - var line string - line, err = br.ReadString('\n') - if err != nil { - return nil, err - } - lineBytes = []byte(line) - } else { - return nil, err - } - } - - lineBytes = bytes.TrimRight(lineBytes, "\r\n") - - if len(lineBytes) == 0 { - break - } - - colonIdx := bytes.IndexByte(lineBytes, ':') - if colonIdx == -1 { - continue - } - - key := bytes.TrimSpace(lineBytes[:colonIdx]) - value := bytes.TrimSpace(lineBytes[colonIdx+1:]) - - header.headers[string(key)] = string(value) - } - - return header, nil -} - -func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager { - header := &responseHeaderFactory{ - startLine: nil, - headers: make(map[string]string), - } - lines := bytes.Split(startLine, []byte("\r\n")) - if len(lines) == 0 { - return header - } - header.startLine = lines[0] - for _, h := range lines[1:] { - if len(h) == 0 { - continue - } - - parts := bytes.SplitN(h, []byte(":"), 2) - if len(parts) < 2 { - continue - } - - key := parts[0] - val := bytes.TrimSpace(parts[1]) - header.headers[string(key)] = string(val) - } - return header -} - -func (resp *responseHeaderFactory) Get(key string) string { - return resp.headers[key] -} - -func (resp *responseHeaderFactory) Set(key string, value string) { - resp.headers[key] = value -} - -func (resp *responseHeaderFactory) Remove(key string) { - delete(resp.headers, key) -} - -func (resp *responseHeaderFactory) Finalize() []byte { - var buf bytes.Buffer - - buf.Write(resp.startLine) - buf.WriteString("\r\n") - - for key, val := range resp.headers { - buf.WriteString(key) - buf.WriteString(": ") - buf.WriteString(val) - buf.WriteString("\r\n") - } - - buf.WriteString("\r\n") - return buf.Bytes() -} - -func (req *requestHeaderFactory) Get(key string) string { - val, ok := req.headers[key] - if !ok { - return "" - } - return val -} - -func (req *requestHeaderFactory) Set(key string, value string) { - req.headers[key] = value -} - -func (req *requestHeaderFactory) Remove(key string) { - delete(req.headers, key) -} - -func (req *requestHeaderFactory) GetMethod() string { - return req.method -} - -func (req *requestHeaderFactory) GetPath() string { - return req.path -} - -func (req *requestHeaderFactory) GetVersion() string { - return req.version -} - -func (req *requestHeaderFactory) Finalize() []byte { - var buf bytes.Buffer - - buf.Write(req.startLine) - buf.WriteString("\r\n") - - for key, val := range req.headers { - buf.WriteString(key) - buf.WriteString(": ") - buf.WriteString(val) - buf.WriteString("\r\n") - } - - buf.WriteString("\r\n") - return buf.Bytes() -} diff --git a/server/http.go b/server/http.go deleted file mode 100644 index e2143d5..0000000 --- a/server/http.go +++ /dev/null @@ -1,395 +0,0 @@ -package server - -import ( - "bufio" - "bytes" - "errors" - "fmt" - "io" - "log" - "net" - "regexp" - "strings" - "time" - "tunnel_pls/internal/config" - "tunnel_pls/session" - "tunnel_pls/types" - - "golang.org/x/crypto/ssh" -) - -type HTTPWriter interface { - io.Reader - io.Writer - GetRemoteAddr() net.Addr - GetWriter() io.Writer - AddResponseMiddleware(mw ResponseMiddleware) - AddRequestStartMiddleware(mw RequestMiddleware) - SetRequestHeader(header RequestHeaderManager) - GetRequestStartMiddleware() []RequestMiddleware -} - -type customWriter struct { - remoteAddr net.Addr - writer io.Writer - reader io.Reader - headerBuf []byte - buf []byte - respHeader ResponseHeaderManager - reqHeader RequestHeaderManager - respMW []ResponseMiddleware - reqStartMW []RequestMiddleware - reqEndMW []RequestMiddleware -} - -func (cw *customWriter) GetRemoteAddr() net.Addr { - return cw.remoteAddr -} - -func (cw *customWriter) GetWriter() io.Writer { - return cw.writer -} - -func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) { - cw.respMW = append(cw.respMW, mw) -} - -func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) { - cw.reqStartMW = append(cw.reqStartMW, mw) -} - -func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) { - cw.reqHeader = header -} - -func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware { - return cw.reqStartMW -} - -func (cw *customWriter) Read(p []byte) (int, error) { - tmp := make([]byte, len(p)) - read, err := cw.reader.Read(tmp) - if read == 0 && err != nil { - return 0, err - } - - tmp = tmp[:read] - - idx := bytes.Index(tmp, DELIMITER) - if idx == -1 { - copy(p, tmp) - if err != nil { - return read, err - } - return read, nil - } - - header := tmp[:idx+len(DELIMITER)] - body := tmp[idx+len(DELIMITER):] - - if !isHTTPHeader(header) { - copy(p, tmp) - return read, nil - } - - for _, m := range cw.reqEndMW { - err = m.HandleRequest(cw.reqHeader) - if err != nil { - log.Printf("Error when applying request middleware: %v", err) - return 0, err - } - } - - reqhf, err := NewRequestHeaderFactory(header) - if err != nil { - return 0, err - } - - for _, m := range cw.reqStartMW { - if mwErr := m.HandleRequest(reqhf); mwErr != nil { - log.Printf("Error when applying request middleware: %v", mwErr) - return 0, mwErr - } - } - - cw.reqHeader = reqhf - finalHeader := reqhf.Finalize() - - combined := append(finalHeader, body...) - - n := copy(p, combined) - - return n, nil -} - -func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter { - return &customWriter{ - remoteAddr: remoteAddr, - writer: writer, - reader: reader, - buf: make([]byte, 0, 4096), - } -} - -var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} -var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) -var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) - -func isHTTPHeader(buf []byte) bool { - lines := bytes.Split(buf, []byte("\r\n")) - - startLine := string(lines[0]) - if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { - return false - } - - for _, line := range lines[1:] { - if len(line) == 0 { - break - } - colonIdx := bytes.IndexByte(line, ':') - if colonIdx <= 0 { - return false - } - } - return true -} - -func (cw *customWriter) Write(p []byte) (int, error) { - if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" { - cw.respHeader = nil - } - - if cw.respHeader != nil { - n, err := cw.writer.Write(p) - if err != nil { - return n, err - } - return n, nil - } - - cw.buf = append(cw.buf, p...) - - idx := bytes.Index(cw.buf, DELIMITER) - if idx == -1 { - return len(p), nil - } - - header := cw.buf[:idx+len(DELIMITER)] - body := cw.buf[idx+len(DELIMITER):] - - if !isHTTPHeader(header) { - _, err := cw.writer.Write(cw.buf) - cw.buf = nil - if err != nil { - return 0, err - } - return len(p), nil - } - - resphf := NewResponseHeaderFactory(header) - for _, m := range cw.respMW { - err := m.HandleResponse(resphf, body) - if err != nil { - log.Printf("Cannot apply middleware: %s\n", err) - return 0, err - } - } - header = resphf.Finalize() - cw.respHeader = resphf - - _, err := cw.writer.Write(header) - if err != nil { - return 0, err - } - if len(body) > 0 { - _, err = cw.writer.Write(body) - if err != nil { - return 0, err - } - } - cw.buf = nil - return len(p), nil -} - -var redirectTLS = false - -type HTTPServer interface { - ListenAndServe() error - ListenAndServeTLS() error - handler(conn net.Conn) - handlerTLS(conn net.Conn) -} -type httpServer struct { - sessionRegistry session.Registry -} - -func NewHTTPServer(sessionRegistry session.Registry) HTTPServer { - return &httpServer{sessionRegistry: sessionRegistry} -} - -func (hs *httpServer) ListenAndServe() error { - httpPort := config.Getenv("HTTP_PORT", "8080") - listener, err := net.Listen("tcp", ":"+httpPort) - if err != nil { - return errors.New("Error listening: " + err.Error()) - } - if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" { - redirectTLS = true - } - go func() { - for { - var conn net.Conn - conn, err = listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go hs.handler(conn) - } - }() - return nil -} - -func (hs *httpServer) handler(conn net.Conn) { - defer func() { - err := conn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - log.Printf("Error closing connection: %v", err) - return - } - return - }() - - dstReader := bufio.NewReader(conn) - reqhf, err := NewRequestHeaderFactory(dstReader) - if err != nil { - log.Printf("Error creating request header: %v", err) - return - } - - host := strings.Split(reqhf.Get("Host"), ".") - if len(host) < 1 { - _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - if err != nil { - log.Println("Failed to write 400 Bad Request:", err) - return - } - return - } - - slug := host[0] - - if redirectTLS { - _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n")) - if err != nil { - log.Println("Failed to write 301 Moved Permanently:", err) - return - } - return - } - - if slug == "ping" { - _, err = conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "Access-Control-Allow-Origin: *\r\n" + - "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + - "Access-Control-Allow-Headers: *\r\n" + - "\r\n", - )) - if err != nil { - log.Println("Failed to write 200 OK:", err) - return - } - return - } - - sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ - Id: slug, - Type: types.HTTP, - }) - if err != nil { - _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n")) - if err != nil { - log.Println("Failed to write 301 Moved Permanently:", err) - return - } - return - } - cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - forwardRequest(cw, reqhf, sshSession) - return -} - -func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) { - payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr()) - - type channelResult struct { - channel ssh.Channel - reqs <-chan *ssh.Request - err error - } - resultChan := make(chan channelResult, 1) - - go func() { - channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) - resultChan <- channelResult{channel, reqs, err} - }() - - var channel ssh.Channel - var reqs <-chan *ssh.Request - - select { - case result := <-resultChan: - if result.err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) - sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) - return - } - channel = result.channel - reqs = result.reqs - case <-time.After(5 * time.Second): - log.Printf("Timeout opening forwarded-tcpip channel") - sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) - return - } - - go ssh.DiscardRequests(reqs) - - fingerprintMiddleware := NewTunnelFingerprint() - forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr()) - - cw.AddResponseMiddleware(fingerprintMiddleware) - cw.AddRequestStartMiddleware(forwardedForMiddleware) - cw.SetRequestHeader(initialRequest) - - for _, m := range cw.GetRequestStartMiddleware() { - if err := m.HandleRequest(initialRequest); err != nil { - log.Printf("Error handling request: %v", err) - return - } - } - - _, err := channel.Write(initialRequest.Finalize()) - if err != nil { - log.Printf("Failed to forward request: %v", err) - return - } - - sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr()) - return -} diff --git a/server/https.go b/server/https.go deleted file mode 100644 index 2758172..0000000 --- a/server/https.go +++ /dev/null @@ -1,112 +0,0 @@ -package server - -import ( - "bufio" - "crypto/tls" - "errors" - "fmt" - "log" - "net" - "strings" - "tunnel_pls/internal/config" - "tunnel_pls/types" -) - -func (hs *httpServer) ListenAndServeTLS() error { - domain := config.Getenv("DOMAIN", "localhost") - httpsPort := config.Getenv("HTTPS_PORT", "8443") - - tlsConfig, err := NewTLSConfig(domain) - if err != nil { - return fmt.Errorf("failed to initialize TLS config: %w", err) - } - - ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig) - if err != nil { - return err - } - - go func() { - for { - var conn net.Conn - conn, err = ln.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - log.Println("https server closed") - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go hs.handlerTLS(conn) - } - }() - return nil -} - -func (hs *httpServer) handlerTLS(conn net.Conn) { - defer func() { - err := conn.Close() - if err != nil { - log.Printf("Error closing connection: %v", err) - return - } - return - }() - - dstReader := bufio.NewReader(conn) - reqhf, err := NewRequestHeaderFactory(dstReader) - if err != nil { - log.Printf("Error creating request header: %v", err) - return - } - - host := strings.Split(reqhf.Get("Host"), ".") - if len(host) < 1 { - _, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - if err != nil { - log.Println("Failed to write 400 Bad Request:", err) - return - } - return - } - - slug := host[0] - - if slug == "ping" { - _, err = conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "Access-Control-Allow-Origin: *\r\n" + - "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + - "Access-Control-Allow-Headers: *\r\n" + - "\r\n", - )) - if err != nil { - log.Println("Failed to write 200 OK:", err) - return - } - return - } - - sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ - Id: slug, - Type: types.HTTP, - }) - if err != nil { - _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n")) - if err != nil { - log.Println("Failed to write 301 Moved Permanently:", err) - return - } - return - } - cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - forwardRequest(cw, reqhf, sshSession) - return -} diff --git a/server/middleware.go b/server/middleware.go deleted file mode 100644 index ee6ca1a..0000000 --- a/server/middleware.go +++ /dev/null @@ -1,41 +0,0 @@ -package server - -import ( - "net" -) - -type RequestMiddleware interface { - HandleRequest(header RequestHeaderManager) error -} - -type ResponseMiddleware interface { - HandleResponse(header ResponseHeaderManager, body []byte) error -} - -type TunnelFingerprint struct{} - -func NewTunnelFingerprint() *TunnelFingerprint { - return &TunnelFingerprint{} -} - -func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error { - header.Set("Server", "Tunnel Please") - return nil -} - -type ForwardedFor struct { - addr net.Addr -} - -func NewForwardedFor(addr net.Addr) *ForwardedFor { - return &ForwardedFor{addr: addr} -} - -func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error { - host, _, err := net.SplitHostPort(ff.addr.String()) - if err != nil { - return err - } - header.Set("X-Forwarded-For", host) - return nil -} diff --git a/server/server.go b/server/server.go index 792f47e..f47c579 100644 --- a/server/server.go +++ b/server/server.go @@ -10,6 +10,7 @@ import ( "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" "tunnel_pls/session" "golang.org/x/crypto/ssh" @@ -20,38 +21,26 @@ type Server interface { Close() error } type server struct { - listener net.Listener - config *ssh.ServerConfig + config config.Config + sshPort string + sshListener net.Listener + sshConfig *ssh.ServerConfig grpcClient client.Client - sessionRegistry session.Registry - portRegistry port.Registry + sessionRegistry registry.Registry + portRegistry port.Port } -func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) { - listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) +func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { + listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort)) if err != nil { - log.Fatalf("failed to listen on port 2200: %v", err) return nil, err } - HttpServer := NewHTTPServer(sessionRegistry) - err = HttpServer.ListenAndServe() - if err != nil { - log.Fatalf("failed to start http server: %v", err) - return nil, err - } - - if config.Getenv("TLS_ENABLED", "false") == "true" { - err = HttpServer.ListenAndServeTLS() - if err != nil { - log.Fatalf("failed to start https server: %v", err) - return nil, err - } - } - return &server{ - listener: listener, - config: sshConfig, + config: config, + sshPort: sshPort, + sshListener: listener, + sshConfig: sshConfig, grpcClient: grpcClient, sessionRegistry: sessionRegistry, portRegistry: portRegistry, @@ -59,9 +48,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie } func (s *server) Start() { - log.Println("SSH server is starting on port 2200...") + log.Printf("SSH server is starting on port %s", s.sshPort) for { - conn, err := s.listener.Accept() + conn, err := s.sshListener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { log.Println("listener closed, stopping server") @@ -76,11 +65,11 @@ func (s *server) Start() { } func (s *server) Close() error { - return s.listener.Close() + return s.sshListener.Close() } func (s *server) handleConnection(conn net.Conn) { - sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) + sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig) if err != nil { log.Printf("failed to establish SSH connection: %v", err) err = conn.Close() @@ -106,7 +95,7 @@ func (s *server) handleConnection(conn net.Conn) { cancel() } log.Println("SSH connection established:", sshConn.User()) - sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) + sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) err = sshSession.Start() if err != nil { log.Printf("SSH session ended with error: %v", err) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 5807ac4..c602565 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" "log" "net" @@ -17,37 +18,6 @@ import ( "golang.org/x/crypto/ssh" ) -var bufferPool = sync.Pool{ - New: func() interface{} { - bufSize := config.GetBufferSize() - return make([]byte, bufSize) - }, -} - -func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { - buf := bufferPool.Get().([]byte) - defer bufferPool.Put(buf) - return io.CopyBuffer(dst, src, buf) -} - -type forwarder struct { - listener net.Listener - tunnelType types.TunnelType - forwardedPort uint16 - slug slug.Slug - conn ssh.Conn -} - -func New(slug slug.Slug, conn ssh.Conn) Forwarder { - return &forwarder{ - listener: nil, - tunnelType: types.UNKNOWN, - forwardedPort: 0, - slug: slug, - conn: conn, - } -} - type Forwarder interface { SetType(tunnelType types.TunnelType) SetForwardedPort(port uint16) @@ -55,110 +25,124 @@ type Forwarder interface { Listener() net.Listener TunnelType() types.TunnelType ForwardedPort() uint16 - HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) + HandleConnection(dst io.ReadWriter, src ssh.Channel) CreateForwardedTCPIPPayload(origin net.Addr) []byte + OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) WriteBadGatewayResponse(dst io.Writer) - AcceptTCPConnections() Close() error } +type forwarder struct { + listener net.Listener + tunnelType types.TunnelType + forwardedPort uint16 + slug slug.Slug + conn ssh.Conn + bufferPool sync.Pool +} -func (f *forwarder) AcceptTCPConnections() { - for { - conn, err := f.Listener().Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) - continue - } - - if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { - log.Printf("Failed to set connection deadline: %v", err) - if closeErr := conn.Close(); closeErr != nil { - log.Printf("Failed to close connection: %v", closeErr) - } - continue - } - - payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) - - type channelResult struct { - channel ssh.Channel - reqs <-chan *ssh.Request - err error - } - resultChan := make(chan channelResult, 1) - - go func() { - channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) - resultChan <- channelResult{channel, reqs, err} - }() - - select { - case result := <-resultChan: - if result.err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) - if closeErr := conn.Close(); closeErr != nil { - log.Printf("Failed to close connection: %v", closeErr) - } - continue - } - - if err = conn.SetDeadline(time.Time{}); err != nil { - log.Printf("Failed to clear connection deadline: %v", err) - } - - go ssh.DiscardRequests(result.reqs) - go f.HandleConnection(conn, result.channel, conn.RemoteAddr()) - - case <-time.After(5 * time.Second): - log.Printf("Timeout opening forwarded-tcpip channel") - if closeErr := conn.Close(); closeErr != nil { - log.Printf("Failed to close connection: %v", closeErr) - } - } +func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder { + return &forwarder{ + listener: nil, + tunnelType: types.TunnelTypeUNKNOWN, + forwardedPort: 0, + slug: slug, + conn: conn, + bufferPool: sync.Pool{ + New: func() interface{} { + bufSize := config.BufferSize() + return make([]byte, bufSize) + }, + }, } } -func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { +func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { + buf := f.bufferPool.Get().([]byte) + defer f.bufferPool.Put(buf) + return io.CopyBuffer(dst, src, buf) +} + +func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { + type channelResult struct { + channel ssh.Channel + reqs <-chan *ssh.Request + err error + } + resultChan := make(chan channelResult, 1) + + go func() { + channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) + select { + case resultChan <- channelResult{channel, reqs, err}: + default: + if channel != nil { + err = channel.Close() + if err != nil { + log.Printf("Failed to close unused channel: %v", err) + return + } + go ssh.DiscardRequests(reqs) + } + } + }() + + select { + case result := <-resultChan: + return result.channel, result.reqs, result.err + case <-time.After(5 * time.Second): + return nil, nil, errors.New("timeout opening forwarded-tcpip channel") + } +} + +func closeWriter(w io.Writer) error { + if cw, ok := w.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + if closer, ok := w.(io.Closer); ok { + return closer.Close() + } + return nil +} + +func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error { + var errs []error + _, err := f.copyWithBuffer(dst, src) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err)) + } + + if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) { + errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err)) + } + return errors.Join(errs...) +} + +func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { defer func() { _, err := io.Copy(io.Discard, src) if err != nil { log.Printf("Failed to discard connection: %v", err) } - - err = src.Close() - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing source channel: %v", err) - } - - if closer, ok := dst.(io.Closer); ok { - err = closer.Close() - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing destination connection: %v", err) - } - } }() - log.Printf("Handling new forwarded connection from %s", remoteAddr) - var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() - _, err := copyWithBuffer(dst, src) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying src→dst: %v", err) + err := f.copyAndClose(dst, src, "src to dst") + if err != nil { + log.Println("Error during copy: ", err) + return } }() go func() { defer wg.Done() - _, err := copyWithBuffer(src, dst) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying dst→src: %v", err) + err := f.copyAndClose(src, dst, "dst to src") + if err != nil { + log.Println("Error during copy: ", err) + return } }() diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 3c02dae..5f68102 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -18,9 +18,9 @@ import ( ) type Interaction interface { - Mode() types.Mode + Mode() types.InteractiveMode SetChannel(channel ssh.Channel) - SetMode(m types.Mode) + SetMode(m types.InteractiveMode) SetWH(w, h int) Start() Redraw() @@ -39,6 +39,7 @@ type Forwarder interface { type CloseFunc func() error type interaction struct { + config config.Config channel ssh.Channel slug slug.Slug forwarder Forwarder @@ -48,14 +49,14 @@ type interaction struct { program *tea.Program ctx context.Context cancel context.CancelFunc - mode types.Mode + mode types.InteractiveMode } -func (i *interaction) SetMode(m types.Mode) { +func (i *interaction) SetMode(m types.InteractiveMode) { i.mode = m } -func (i *interaction) Mode() types.Mode { +func (i *interaction) Mode() types.InteractiveMode { return i.mode } @@ -75,9 +76,10 @@ func (i *interaction) SetWH(w, h int) { } } -func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { +func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { ctx, cancel := context.WithCancel(context.Background()) return &interaction{ + config: config, channel: nil, slug: slug, forwarder: forwarder, @@ -174,14 +176,13 @@ func (m *model) View() string { } func (i *interaction) Start() { - if i.mode == types.HEADLESS { + if i.mode == types.InteractiveModeHEADLESS { return } lipgloss.SetColorProfile(termenv.TrueColor) - domain := config.Getenv("DOMAIN", "localhost") protocol := "http" - if config.Getenv("TLS_ENABLED", "false") == "true" { + if i.config.TLSEnabled() { protocol = "https" } @@ -209,7 +210,7 @@ func (i *interaction) Start() { ti.Width = 50 m := &model{ - domain: domain, + domain: i.config.Domain(), protocol: protocol, tunnelType: tunnelType, port: port, diff --git a/session/interaction/model.go b/session/interaction/model.go index 24b4d26..189b0a1 100644 --- a/session/interaction/model.go +++ b/session/interaction/model.go @@ -41,7 +41,7 @@ type model struct { } func (m *model) getTunnelURL() string { - if m.tunnelType == types.HTTP { + if m.tunnelType == types.TunnelTypeHTTP { return buildURL(m.protocol, m.interaction.slug.String(), m.domain) } return fmt.Sprintf("tcp://%s:%d", m.domain, m.port) diff --git a/session/interaction/slug.go b/session/interaction/slug.go index 6c6a97b..2b871d4 100644 --- a/session/interaction/slug.go +++ b/session/interaction/slug.go @@ -15,7 +15,7 @@ import ( func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { var cmd tea.Cmd - if m.tunnelType != types.HTTP { + if m.tunnelType != types.TunnelTypeHTTP { m.editingSlug = false m.slugError = "" return m, tea.Batch(tea.ClearScreen, textinput.Blink) @@ -30,10 +30,10 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { inputValue := m.slugInput.Value() if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{ Id: m.interaction.slug.String(), - Type: types.HTTP, + Type: types.TunnelTypeHTTP, }, types.SessionKey{ Id: inputValue, - Type: types.HTTP, + Type: types.TunnelTypeHTTP, }); err != nil { m.slugError = err.Error() return m, nil @@ -130,7 +130,7 @@ func (m *model) slugView() string { b.WriteString(titleStyle.Render(title)) b.WriteString("\n\n") - if m.tunnelType != types.HTTP { + if m.tunnelType != types.TunnelTypeHTTP { warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60) warningBoxStyle := lipgloss.NewStyle(). Foreground(lipgloss.Color("#FFA500")). diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 8d134a2..f9f9d6e 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -2,8 +2,6 @@ package lifecycle import ( "errors" - "io" - "net" "time" portUtil "tunnel_pls/internal/port" @@ -24,20 +22,20 @@ type SessionRegistry interface { } type lifecycle struct { - status types.Status + status types.SessionStatus conn ssh.Conn channel ssh.Channel forwarder Forwarder slug slug.Slug startedAt time.Time sessionRegistry SessionRegistry - portRegistry portUtil.Registry + portRegistry portUtil.Port user string } -func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle { +func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle { return &lifecycle{ - status: types.INITIALIZING, + status: types.SessionStatusINITIALIZING, conn: conn, channel: nil, forwarder: forwarder, @@ -51,16 +49,16 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti type Lifecycle interface { Connection() ssh.Conn - PortRegistry() portUtil.Registry + PortRegistry() portUtil.Port User() string SetChannel(channel ssh.Channel) - SetStatus(status types.Status) + SetStatus(status types.SessionStatus) IsActive() bool StartedAt() time.Time Close() error } -func (l *lifecycle) PortRegistry() portUtil.Registry { +func (l *lifecycle) PortRegistry() portUtil.Port { return l.portRegistry } @@ -74,35 +72,30 @@ func (l *lifecycle) SetChannel(channel ssh.Channel) { func (l *lifecycle) Connection() ssh.Conn { return l.conn } -func (l *lifecycle) SetStatus(status types.Status) { +func (l *lifecycle) SetStatus(status types.SessionStatus) { l.status = status - if status == types.RUNNING && l.startedAt.IsZero() { + if status == types.SessionStatusRUNNING && l.startedAt.IsZero() { l.startedAt = time.Now() } } +func closeIfNotNil(c interface{ Close() error }) error { + if c != nil { + return c.Close() + } + return nil +} + func (l *lifecycle) Close() error { - var firstErr error + var errs []error tunnelType := l.forwarder.TunnelType() - if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - firstErr = err + if err := closeIfNotNil(l.channel); err != nil { + errs = append(errs, err) } - if l.channel != nil { - if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - if firstErr == nil { - firstErr = err - } - } - } - - if l.conn != nil { - if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - if firstErr == nil { - firstErr = err - } - } + if err := closeIfNotNil(l.conn); err != nil { + errs = append(errs, err) } clientSlug := l.slug.String() @@ -112,17 +105,20 @@ func (l *lifecycle) Close() error { } l.sessionRegistry.Remove(key) - if tunnelType == types.TCP { - if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { - firstErr = err + if tunnelType == types.TunnelTypeTCP { + if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil { + errs = append(errs, err) + } + if err := l.forwarder.Close(); err != nil { + errs = append(errs, err) } } - return firstErr + return errors.Join(errs...) } func (l *lifecycle) IsActive() bool { - return l.status == types.RUNNING + return l.status == types.SessionStatusRUNNING } func (l *lifecycle) StartedAt() time.Time { diff --git a/session/session.go b/session/session.go index be9e9ed..b1895ab 100644 --- a/session/session.go +++ b/session/session.go @@ -12,6 +12,8 @@ import ( "tunnel_pls/internal/config" portUtil "tunnel_pls/internal/port" "tunnel_pls/internal/random" + "tunnel_pls/internal/registry" + "tunnel_pls/internal/transport" "tunnel_pls/session/forwarder" "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" @@ -21,46 +23,40 @@ import ( "golang.org/x/crypto/ssh" ) -type Detail struct { - ForwardingType string `json:"forwarding_type,omitempty"` - Slug string `json:"slug,omitempty"` - UserID string `json:"user_id,omitempty"` - Active bool `json:"active,omitempty"` - StartedAt time.Time `json:"started_at,omitempty"` -} - type Session interface { - HandleGlobalRequest(ch <-chan *ssh.Request) - HandleTCPIPForward(req *ssh.Request) - HandleHTTPForward(req *ssh.Request, port uint16) - HandleTCPForward(req *ssh.Request, addr string, port uint16) + HandleGlobalRequest(ch <-chan *ssh.Request) error + HandleTCPIPForward(req *ssh.Request) error + HandleHTTPForward(req *ssh.Request, port uint16) error + HandleTCPForward(req *ssh.Request, addr string, port uint16) error Lifecycle() lifecycle.Lifecycle Interaction() interaction.Interaction Forwarder() forwarder.Forwarder Slug() slug.Slug - Detail() *Detail + Detail() *types.Detail Start() error } type session struct { + config config.Config initialReq <-chan *ssh.Request sshChan <-chan ssh.NewChannel lifecycle lifecycle.Lifecycle interaction interaction.Interaction forwarder forwarder.Forwarder slug slug.Slug - registry Registry + registry registry.Registry } var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session { +func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { slugManager := slug.New() - forwarderManager := forwarder.New(slugManager, conn) + forwarderManager := forwarder.New(config, slugManager, conn) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) - interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) + interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) return &session{ + config: config, initialReq: initialReq, sshChan: sshChan, lifecycle: lifecycleManager, @@ -87,16 +83,17 @@ func (s *session) Slug() slug.Slug { return s.slug } -func (s *session) Detail() *Detail { - var tunnelType string - if s.forwarder.TunnelType() == types.HTTP { - tunnelType = "HTTP" - } else if s.forwarder.TunnelType() == types.TCP { - tunnelType = "TCP" - } else { - tunnelType = "UNKNOWN" +func (s *session) Detail() *types.Detail { + tunnelTypeMap := map[types.TunnelType]string{ + types.TunnelTypeHTTP: "TunnelTypeHTTP", + types.TunnelTypeTCP: "TunnelTypeTCP", } - return &Detail{ + tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()] + if !ok { + tunnelType = "TunnelTypeUNKNOWN" + } + + return &types.Detail{ ForwardingType: tunnelType, Slug: s.slug.String(), UserID: s.lifecycle.User(), @@ -106,55 +103,80 @@ func (s *session) Detail() *Detail { } func (s *session) Start() error { - var channel ssh.NewChannel - var ok bool - select { - case channel, ok = <-s.sshChan: - if !ok { - log.Println("Forwarding request channel closed") - return nil - } - ch, reqs, err := channel.Accept() - if err != nil { - log.Printf("failed to accept channel: %v", err) - return err - } - go s.HandleGlobalRequest(reqs) - - s.lifecycle.SetChannel(ch) - s.interaction.SetChannel(ch) - s.interaction.SetMode(types.INTERACTIVE) - case <-time.After(500 * time.Millisecond): - s.interaction.SetMode(types.HEADLESS) + if err := s.setupSessionMode(); err != nil { + return err } tcpipReq := s.waitForTCPIPForward() if tcpipReq == nil { - err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) - if err != nil { - return err - } - if err = s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } - return fmt.Errorf("no forwarding Request") + return s.handleMissingForwardRequest() } - if (s.interaction.Mode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.User() == "UNAUTHORIZED" { - if err := tcpipReq.Reply(false, nil); err != nil { - log.Printf("cannot reply to tcpip req: %s\n", err) - return err - } - if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - return err - } - return nil + if s.shouldRejectUnauthorized() { + return s.denyForwardingRequest(tcpipReq, nil, nil, fmt.Sprintf("headless forwarding only allowed on node mode")) } - s.HandleTCPIPForward(tcpipReq) + if err := s.HandleTCPIPForward(tcpipReq); err != nil { + return err + } s.interaction.Start() + return s.waitForSessionEnd() +} + +func (s *session) setupSessionMode() error { + select { + case channel, ok := <-s.sshChan: + if !ok { + log.Println("Forwarding request channel closed") + return nil + } + return s.setupInteractiveMode(channel) + case <-time.After(500 * time.Millisecond): + s.interaction.SetMode(types.InteractiveModeHEADLESS) + return nil + } +} + +func (s *session) setupInteractiveMode(channel ssh.NewChannel) error { + ch, reqs, err := channel.Accept() + if err != nil { + log.Printf("failed to accept channel: %v", err) + return err + } + + go func() { + err = s.HandleGlobalRequest(reqs) + if err != nil { + log.Printf("global request handler error: %v", err) + } + }() + + s.lifecycle.SetChannel(ch) + s.interaction.SetChannel(ch) + s.interaction.SetMode(types.InteractiveModeINTERACTIVE) + + return nil +} + +func (s *session) handleMissingForwardRequest() error { + err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort())) + if err != nil { + return err + } + if err = s.lifecycle.Close(); err != nil { + log.Printf("failed to close session: %v", err) + } + return fmt.Errorf("no forwarding Request") +} + +func (s *session) shouldRejectUnauthorized() bool { + return s.interaction.Mode() == types.InteractiveModeHEADLESS && + s.config.Mode() == types.ServerModeSTANDALONE && + s.lifecycle.User() == "UNAUTHORIZED" +} + +func (s *session) waitForSessionEnd() error { if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { log.Printf("ssh connection closed with error: %v", err) } @@ -187,220 +209,191 @@ func (s *session) waitForTCPIPForward() *ssh.Request { } } -func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { +func (s *session) handleWindowChange(req *ssh.Request) error { + p := req.Payload + if len(p) < 16 { + log.Println("invalid window-change payload") + return req.Reply(false, nil) + } + + cols := binary.BigEndian.Uint32(p[0:4]) + rows := binary.BigEndian.Uint32(p[4:8]) + + s.interaction.SetWH(int(cols), int(rows)) + return req.Reply(true, nil) +} + +func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error { for req := range GlobalRequest { switch req.Type { case "shell", "pty-req": err := req.Reply(true, nil) if err != nil { - log.Println("Failed to reply to request:", err) - return + return err } case "window-change": - p := req.Payload - if len(p) < 16 { - log.Println("invalid window-change payload") - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - return - } - cols := binary.BigEndian.Uint32(p[0:4]) - rows := binary.BigEndian.Uint32(p[4:8]) - - s.interaction.SetWH(int(cols), int(rows)) - - err := req.Reply(true, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return + if err := s.handleWindowChange(req); err != nil { + return err } default: log.Println("Unknown request type:", req.Type) err := req.Reply(false, nil) if err != nil { - log.Println("Failed to reply to request:", err) - return + return err } } } + return nil } -func (s *session) HandleTCPIPForward(req *ssh.Request) { - log.Println("PortRegistry forwarding request detected") - - fail := func(msg string) { - log.Println(msg) - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - return - } - if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } - } - - reader := bytes.NewReader(req.Payload) - - addr, err := readSSHString(reader) +func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) { + address, err = readSSHString(payloadReader) if err != nil { - fail(fmt.Sprintf("Failed to read address from payload: %v", err)) - return + return "", 0, err } var rawPortToBind uint32 - if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { - fail(fmt.Sprintf("Failed to read port from payload: %v", err)) - return + if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil { + return "", 0, err } if rawPortToBind > 65535 { - fail(fmt.Sprintf("PortRegistry %d is larger than allowed port of 65535", rawPortToBind)) - return + return "", 0, fmt.Errorf("port is larger than allowed port of 65535") } - portToBind := uint16(rawPortToBind) - if isBlockedPort(portToBind) { - fail(fmt.Sprintf("PortRegistry %d is blocked or restricted", portToBind)) - return + port = uint16(rawPortToBind) + if isBlockedPort(port) { + return "", 0, fmt.Errorf("port is block") } - switch portToBind { + if port == 0 { + unassigned, ok := s.lifecycle.PortRegistry().Unassigned() + if !ok { + return "", 0, fmt.Errorf("no available port") + } + return address, unassigned, err + } + + return address, port, err +} + +func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error { + var errs []error + if key != nil { + s.registry.Remove(*key) + } + if listener != nil { + if err := listener.Close(); err != nil { + errs = append(errs, fmt.Errorf("close listener: %w", err)) + } + } + if err := req.Reply(false, nil); err != nil { + errs = append(errs, fmt.Errorf("reply request: %w", err)) + } + if err := s.lifecycle.Close(); err != nil { + errs = append(errs, fmt.Errorf("close session: %w", err)) + } + errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg)) + return errors.Join(errs...) +} + +func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) { + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.BigEndian, uint32(port)) + if err != nil { + return err + } + + err = req.Reply(true, buf.Bytes()) + if err != nil { + return err + } + return nil +} + +func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error { + err := s.approveForwardingRequest(req, portToBind) + if err != nil { + return err + } + + s.forwarder.SetType(tunnelType) + s.forwarder.SetForwardedPort(portToBind) + s.slug.Set(slug) + s.lifecycle.SetStatus(types.SessionStatusRUNNING) + + if listener != nil { + s.forwarder.SetListener(listener) + } + + return nil +} + +func (s *session) HandleTCPIPForward(req *ssh.Request) error { + reader := bytes.NewReader(req.Payload) + + address, port, err := s.parseForwardPayload(reader) + if err != nil { + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error())) + } + + switch port { case 80, 443: - s.HandleHTTPForward(req, portToBind) + return s.HandleHTTPForward(req, port) default: - s.HandleTCPForward(req, addr, portToBind) + return s.HandleTCPForward(req, address, port) } } -func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) { - fail := func(msg string, key *types.SessionKey) { - log.Println(msg) - if key != nil { - s.registry.Remove(*key) - } - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - } - } - +func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { randomString, err := random.GenerateRandomString(20) if err != nil { - fail(fmt.Sprintf("Failed to create slug: %s", err), nil) - return + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err)) } - key := types.SessionKey{Id: randomString, Type: types.HTTP} + key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP} if !s.registry.Register(key, s) { - fail(fmt.Sprintf("Failed to register client with slug: %s", randomString), nil) - return + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString)) } - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) + err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id) if err != nil { - fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key) - return + return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } - log.Printf("HTTP forwarding approved on port: %d", portToBind) - - err = req.Reply(true, buf.Bytes()) - if err != nil { - fail(fmt.Sprintf("Failed to reply to request: %v", err), &key) - return - } - - s.forwarder.SetType(types.HTTP) - s.forwarder.SetForwardedPort(portToBind) - s.slug.Set(randomString) - s.lifecycle.SetStatus(types.RUNNING) + return nil } -func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - fail := func(msg string) { - log.Println(msg) - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - return - } - if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } +func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error { + if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed { + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } - cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) { - log.Println(msg) - if key != nil { - s.registry.Remove(*key) - } - if port != 0 { - if setErr := s.lifecycle.PortRegistry().SetPortStatus(port, false); setErr != nil { - log.Printf("Failed to reset port status: %v", setErr) - } - } - if listener != nil { - if closeErr := listener.Close(); closeErr != nil { - log.Printf("Failed to close listener: %v", closeErr) - } - } - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - } - _ = s.lifecycle.Close() - } - - if portToBind == 0 { - unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort() - if !ok { - fail("No available port") - return - } - portToBind = unassigned - } - - if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed { - fail(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) - return - } - - log.Printf("Requested forwarding on %s:%d", addr, portToBind) - listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) + tcpServer := transport.NewTCPServer(portToBind, s.forwarder) + listener, err := tcpServer.Listen() if err != nil { - cleanup(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind), portToBind, nil, nil) - return + return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } - key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} - + key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP} if !s.registry.Register(key, s) { - cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil) - return + return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id)) } - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) + err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id) if err != nil { - cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key) - return + return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } - log.Printf("TCP forwarding approved on port: %d", portToBind) - err = req.Reply(true, buf.Bytes()) - if err != nil { - cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key) - return - } + go func() { + err = tcpServer.Serve(listener) + if err != nil { + log.Printf("Failed serving tcp server: %s\n", err) + } + }() - s.forwarder.SetType(types.TCP) - s.forwarder.SetListener(listener) - s.forwarder.SetForwardedPort(portToBind) - s.slug.Set(key.Id) - s.lifecycle.SetStatus(types.RUNNING) - go s.forwarder.AcceptTCPConnections() + return nil } -func readSSHString(reader *bytes.Reader) (string, error) { +func readSSHString(reader io.Reader) (string, error) { var length uint32 if err := binary.Read(reader, binary.BigEndian, &length); err != nil { return "", err diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..277a293 --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1 @@ +sonar.projectKey=tunnel-please \ No newline at end of file diff --git a/types/types.go b/types/types.go index 148cd2b..34ccfb4 100644 --- a/types/types.go +++ b/types/types.go @@ -1,25 +1,34 @@ package types -type Status int +import "time" + +type SessionStatus int const ( - INITIALIZING Status = iota - RUNNING + SessionStatusINITIALIZING SessionStatus = iota + SessionStatusRUNNING ) -type Mode int +type InteractiveMode int const ( - INTERACTIVE Mode = iota - HEADLESS + InteractiveModeINTERACTIVE InteractiveMode = iota + 1 + InteractiveModeHEADLESS ) type TunnelType int const ( - UNKNOWN TunnelType = iota - HTTP - TCP + TunnelTypeUNKNOWN TunnelType = iota + TunnelTypeHTTP + TunnelTypeTCP +) + +type ServerMode int + +const ( + ServerModeSTANDALONE = iota + 1 + ServerModeNODE ) type SessionKey struct { @@ -27,6 +36,14 @@ type SessionKey struct { Type TunnelType } +type Detail struct { + ForwardingType string `json:"forwarding_type,omitempty"` + Slug string `json:"slug,omitempty"` + UserID string `json:"user_id,omitempty"` + Active bool `json:"active,omitempty"` + StartedAt time.Time `json:"started_at,omitempty"` +} + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + "Content-Length: 11\r\n" + "Content-Type: text/plain\r\n\r\n" +