package transport

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"tunnel_pls/internal/config"
	"tunnel_pls/internal/http/header"
	"tunnel_pls/internal/http/stream"
	"tunnel_pls/internal/middleware"
	"tunnel_pls/internal/registry"
	"tunnel_pls/types"

	"golang.org/x/crypto/ssh"
)

// httpHandler accepts raw client connections, parses the HTTP request
// header, resolves the target tunnel from the Host subdomain and forwards
// the traffic over the matching SSH session.
type httpHandler struct {
	config          config.Config
	sessionRegistry registry.Registry
}

// newHTTPHandler builds an httpHandler that reads its settings from config
// and looks tunnels up in sessionRegistry.
func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler {
	return &httpHandler{
		config:          config,
		sessionRegistry: sessionRegistry,
	}
}

// redirect writes a bodiless redirect response with the given status code
// and announces that the connection will be closed.
//
// location is written verbatim as the Location header value; the header's
// CRLF terminator is appended here, so callers must NOT include one. A
// location containing CR or LF is rejected to prevent header injection.
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
	if strings.ContainsAny(location, "\r\n") {
		return fmt.Errorf("invalid redirect location %q", location)
	}
	response := fmt.Sprintf("HTTP/1.1 %d %s\r\n", status, http.StatusText(status)) +
		"Location: " + location + "\r\n" +
		"Content-Length: 0\r\n" +
		"Connection: close\r\n" +
		"\r\n"
	_, err := conn.Write([]byte(response))
	return err
}

// badRequest writes an empty 400 response and announces that the
// connection will be closed.
func (hh *httpHandler) badRequest(conn net.Conn) error {
	_, err := conn.Write([]byte(
		"HTTP/1.1 400 Bad Request\r\n" +
			"Content-Length: 0\r\n" +
			"Connection: close\r\n" +
			"\r\n",
	))
	return err
}

// Handler serves a single client connection. isTLS reports whether the
// connection arrived over TLS; plain connections are redirected to HTTPS
// when the configuration asks for it. The connection is always closed
// before Handler returns.
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
	// Upper bound on how long a client may take to deliver its request header.
	const headerReadTimeout = 10 * time.Second

	defer hh.closeConnection(conn)

	_ = conn.SetReadDeadline(time.Now().Add(headerReadTimeout))
	raw, err := hh.readRequestHeader(conn)
	if err != nil {
		_ = hh.badRequest(conn)
		return
	}
	// The header is in; the forwarded stream must not inherit the deadline.
	_ = conn.SetReadDeadline(time.Time{})

	reqhf, err := header.NewRequest(raw)
	if err != nil {
		log.Printf("Error creating request header: %v", err)
		_ = hh.badRequest(conn)
		return
	}

	slug, err := hh.extractSlug(reqhf)
	if err != nil {
		_ = hh.badRequest(conn)
		return
	}

	if hh.shouldRedirectToTLS(isTLS) {
		_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/", slug, hh.config.Domain()))
		return
	}

	if hh.handlePingRequest(slug, conn) {
		return
	}

	sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
		Id:   slug,
		Type: types.TunnelTypeHTTP,
	})
	if err != nil {
		// The slug is client-controlled: escape it before putting it in a URL.
		_ = hh.redirect(conn, http.StatusMovedPermanently, "https://tunnl.live/tunnel-not-found?slug="+url.QueryEscape(slug))
		return
	}

	hw := stream.New(conn, conn, conn.RemoteAddr())
	defer func() {
		if closeErr := hw.Close(); closeErr != nil {
			log.Printf("Error closing HTTP stream: %v", closeErr)
		}
	}()
	hh.forwardRequest(hw, reqhf, sshSession)
}

// readRequestHeader reads from conn until the end-of-header marker
// (CRLF CRLF) has arrived, the configured header size is exhausted, or a
// read fails. A header may arrive across several TCP segments, so a single
// Read is not enough. On success it returns every byte read so far, which
// may include the beginning of a request body.
func (hh *httpHandler) readRequestHeader(conn net.Conn) ([]byte, error) {
	terminator := []byte("\r\n\r\n")
	buf := make([]byte, hh.config.HeaderSize())
	n := 0
	for n < len(buf) {
		m, err := conn.Read(buf[n:])
		// Rescan only the region where a newly completed terminator can lie.
		from := n - (len(terminator) - 1)
		if from < 0 {
			from = 0
		}
		n += m
		// Per the io.Reader contract, consume returned bytes before the error.
		if bytes.Contains(buf[from:n], terminator) {
			return buf[:n], nil
		}
		if err != nil {
			return nil, err
		}
	}
	return nil, errors.New("request header exceeds configured size")
}

// closeConnection closes conn, logging any failure other than the
// connection already being closed.
func (hh *httpHandler) closeConnection(conn net.Conn) {
	err := conn.Close()
	if err != nil && !errors.Is(err, net.ErrClosed) {
		log.Printf("Error closing connection: %v", err)
	}
}

// extractSlug returns the left-most label of the request's Host header,
// which selects the tunnel. A Host without a dot, or with an empty first
// label, is rejected.
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
	host := reqhf.Value("Host")
	idx := strings.IndexByte(host, '.')
	if idx <= 0 {
		return "", errors.New("invalid host")
	}
	return host[:idx], nil
}

// shouldRedirectToTLS reports whether a connection should be redirected to
// HTTPS: it arrived in plain text and TLS redirection is enabled.
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
	return !isTLS && hh.config.TLSRedirect()
}

// handlePingRequest answers the reserved "ping" slug with an empty,
// CORS-permissive 200 response. It reports whether the request was handled,
// even if writing the response failed, so the caller stops processing it.
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
	if slug != "ping" {
		return false
	}

	_, err := conn.Write([]byte(
		"HTTP/1.1 200 OK\r\n" +
			"Content-Length: 0\r\n" +
			"Connection: close\r\n" +
			"Access-Control-Allow-Origin: *\r\n" +
			"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
			"Access-Control-Allow-Headers: *\r\n" +
			"\r\n",
	))
	if err != nil {
		log.Println("Failed to write 200 OK:", err)
	}
	return true
}

// forwardRequest opens a forwarded-tcpip channel on sshSession (giving up
// after 5 seconds), sends the already-parsed initial request through it and
// then hands the stream and channel to the session's forwarder. The channel
// is closed when forwarding returns.
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
	payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, payload)
	if err != nil {
		log.Printf("Failed to open forwarded-tcpip channel: %v", err)
		return
	}

	// Out-of-band channel requests are not used; drain them so they never block.
	go ssh.DiscardRequests(reqs)

	defer func() {
		if closeErr := channel.Close(); closeErr != nil && !errors.Is(closeErr, io.EOF) {
			log.Printf("Error closing forwarded channel: %v", closeErr)
		}
	}()

	hh.setupMiddlewares(hw)

	if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
		log.Printf("Failed to forward initial request: %v", err)
		return
	}

	// NOTE(review): presumably pumps the remaining traffic in both directions
	// until one side closes; its implementation is not visible here.
	sshSession.Forwarder().HandleConnection(hw, channel)
}

// setupMiddlewares registers the tunnel-fingerprint response middleware and
// the forwarded-for request middleware (seeded with the client address) on hw.
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
	fingerprintMiddleware := middleware.NewTunnelFingerprint()
	forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())

	hw.UseResponseMiddleware(fingerprintMiddleware)
	hw.UseRequestMiddleware(forwardedForMiddleware)
}

// sendInitialRequest records initialRequest on hw, applies the request
// middlewares to it and writes the finalized header bytes to channel.
func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error {
	hw.SetRequestHeader(initialRequest)

	if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
		return fmt.Errorf("error applying request middlewares: %w", err)
	}

	if _, err := channel.Write(initialRequest.Finalize()); err != nil {
		return fmt.Errorf("error writing to channel: %w", err)
	}
	return nil
}