Files
tunnel-please/server/http.go
bagas 27f49879af
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m43s
refactor(server): enhance HTTP handler modularity and fix resource leak
- Rename customWriter struct to httpWriter for clarity
- Add closeWriter field to properly close write side of connections
- Update all cw variable references to hw
- Merge handlerTLS into handler function to reduce code duplication
- Extract handler into smaller, focused methods
- Split Read/Write/forwardRequest into composable functions

Fixes resource leak where connections weren't properly closed on the
write side, matching the forwarder's CloseWrite() pattern.
2026-01-19 22:41:04 +07:00

276 lines
6.5 KiB
Go

package server
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/session"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
// HTTPServer starts the tunnel's public HTTP front ends.
//
// As implemented by httpServer in this file, each method binds its listener,
// starts accepting connections in a background goroutine and returns
// immediately; a non-nil error means the listener could not be set up.
type HTTPServer interface {
	// ListenAndServe serves plain HTTP on HTTP_PORT (default "8080").
	ListenAndServe() error
	// ListenAndServeTLS serves HTTPS on HTTPS_PORT (default "8443") using a
	// TLS config built for DOMAIN (default "localhost").
	ListenAndServeTLS() error
}
// httpServer is the HTTPServer implementation that routes each incoming
// request to an SSH tunnel session chosen by the first label of the Host
// header.
type httpServer struct {
	// sessionRegistry is queried (with types.HTTP keys) to find the session
	// that owns a slug.
	sessionRegistry session.Registry
	// redirectTLS, when true, makes requests arriving on the plain-HTTP
	// listener receive a 301 redirect to the HTTPS URL instead of being served.
	redirectTLS bool
}
// NewHTTPServer builds the HTTP front end.
//
// sessionRegistry is used to resolve tunnel slugs to sessions; redirectTLS
// controls whether plain-HTTP requests are bounced to HTTPS.
func NewHTTPServer(sessionRegistry session.Registry, redirectTLS bool) HTTPServer {
	srv := new(httpServer)
	srv.sessionRegistry = sessionRegistry
	srv.redirectTLS = redirectTLS
	return srv
}
// ListenAndServe binds a plain-TCP listener on HTTP_PORT (default "8080") and
// serves every accepted connection with hs.handler in its own goroutine.
//
// It returns as soon as the listener is bound. Accepting happens in a
// background goroutine that exits once the listener is closed. A bind failure
// is returned wrapped (%w) so callers can inspect the cause with errors.Is/As.
func (hs *httpServer) ListenAndServe() error {
	httpPort := config.Getenv("HTTP_PORT", "8080")
	listener, err := net.Listen("tcp", ":"+httpPort)
	if err != nil {
		// Same message text as before, but the cause is now wrapped.
		return fmt.Errorf("Error listening: %w", err)
	}
	go func() {
		// delay throttles retries after Accept failures (e.g. EMFILE) so the
		// loop does not spin the CPU; it resets after a successful Accept.
		var delay time.Duration
		for {
			// Use goroutine-local variables instead of writing the enclosing
			// function's err from another goroutine.
			conn, acceptErr := listener.Accept()
			if acceptErr != nil {
				if errors.Is(acceptErr, net.ErrClosed) {
					return
				}
				log.Printf("Error accepting connection: %v", acceptErr)
				if delay == 0 {
					delay = 5 * time.Millisecond
				} else {
					delay *= 2
				}
				if delay > time.Second {
					delay = time.Second
				}
				time.Sleep(delay)
				continue
			}
			delay = 0
			go hs.handler(conn, false)
		}
	}()
	return nil
}
// ListenAndServeTLS binds a TLS listener on HTTPS_PORT (default "8443") using
// the TLS config built for DOMAIN (default "localhost"), and serves every
// accepted connection with hs.handler in its own goroutine.
//
// It returns as soon as the listener is bound. Accepting happens in a
// background goroutine that stops once the listener is closed.
func (hs *httpServer) ListenAndServeTLS() error {
	domain := config.Getenv("DOMAIN", "localhost")
	httpsPort := config.Getenv("HTTPS_PORT", "8443")
	tlsConfig, err := NewTLSConfig(domain)
	if err != nil {
		return fmt.Errorf("failed to initialize TLS config: %w", err)
	}
	ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig)
	if err != nil {
		return err
	}
	go func() {
		// delay throttles retries after Accept failures so the loop does not
		// spin the CPU; it resets after a successful Accept.
		var delay time.Duration
		for {
			conn, acceptErr := ln.Accept()
			if acceptErr != nil {
				if errors.Is(acceptErr, net.ErrClosed) {
					// Stop here: previously this fell through to `continue`
					// and busy-looped forever on the closed listener.
					log.Println("https server closed")
					return
				}
				log.Printf("Error accepting connection: %v", acceptErr)
				if delay == 0 {
					delay = 5 * time.Millisecond
				} else {
					delay *= 2
				}
				if delay > time.Second {
					delay = time.Second
				}
				time.Sleep(delay)
				continue
			}
			delay = 0
			go hs.handler(conn, true)
		}
	}()
	return nil
}
// redirect writes a minimal HTTP/1.1 redirect response to conn.
//
// status is the HTTP status code (e.g. http.StatusMovedPermanently); its
// standard reason phrase is used on the status line. location is the target
// URL. A trailing CR/LF is tolerated (callers historically appended one) and
// stripped, and the Location header line is always terminated here. A
// location that still contains CR or LF is rejected to prevent response
// header injection. The response has an empty body and "Connection: close".
// It returns any error from validation or from writing to conn.
func (hs *httpServer) redirect(conn net.Conn, status int, location string) error {
	location = strings.TrimRight(location, "\r\n")
	if strings.ContainsAny(location, "\r\n") {
		return fmt.Errorf("invalid redirect location %q", location)
	}
	response := fmt.Sprintf("HTTP/1.1 %d %s\r\n", status, http.StatusText(status)) +
		"Location: " + location + "\r\n" +
		"Content-Length: 0\r\n" +
		"Connection: close\r\n" +
		"\r\n"
	if _, err := conn.Write([]byte(response)); err != nil {
		return err
	}
	return nil
}
// badRequest sends a bare "400 Bad Request" status line followed by an empty
// header section, returning whatever error the write to conn produced.
func (hs *httpServer) badRequest(conn net.Conn) error {
	const response = "HTTP/1.1 400 Bad Request\r\n\r\n"
	_, err := conn.Write([]byte(response))
	return err
}
// handler serves one client connection accepted by either listener; isTLS
// tells it which one.
//
// It parses the request header from conn and derives the tunnel slug from
// the Host header. Then it does one of the following:
//   - answers 400 when no slug can be derived;
//   - redirects plain-HTTP requests to HTTPS when redirectTLS is set;
//   - answers the built-in "ping" check;
//   - redirects unknown slugs to the tunnel-not-found page;
//   - otherwise forwards the request over the matching session.
//
// conn is always closed when handler returns.
func (hs *httpServer) handler(conn net.Conn, isTLS bool) {
	defer hs.closeConnection(conn)
	dstReader := bufio.NewReader(conn)
	reqhf, err := NewRequestHeaderFactory(dstReader)
	if err != nil {
		log.Printf("Error creating request header: %v", err)
		return
	}
	slug, err := hs.extractSlug(reqhf)
	if err != nil {
		// Best effort: the connection is closed by the deferred call anyway.
		_ = hs.badRequest(conn)
		return
	}
	if hs.shouldRedirectToTLS(isTLS) {
		// Pass only the target URL: redirect writes the "Location: " header
		// name itself, so including it here produced a malformed
		// "Location: Location: https://..." header. The trailing CRLF matches
		// the not-found redirect below.
		target := fmt.Sprintf("https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))
		_ = hs.redirect(conn, http.StatusMovedPermanently, target)
		return
	}
	if hs.handlePingRequest(slug, conn) {
		return
	}
	sshSession, err := hs.getSession(slug)
	if err != nil {
		// Best effort, as above.
		_ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
		return
	}
	hw := NewHTTPWriter(conn, dstReader, conn.RemoteAddr())
	hs.forwardRequest(hw, reqhf, sshSession)
}
// closeConnection closes conn. An "already closed" error (net.ErrClosed) is
// silently ignored; any other close failure is logged rather than returned.
func (hs *httpServer) closeConnection(conn net.Conn) {
	closeErr := conn.Close()
	if closeErr == nil || errors.Is(closeErr, net.ErrClosed) {
		return
	}
	log.Printf("Error closing connection: %v", closeErr)
}
// extractSlug returns the tunnel slug, which is the first dot-separated label of
// the request's Host header (e.g. "abc" for "abc.example.com").
//
// It returns an "invalid host" error when that label is empty, i.e. for a
// missing/empty Host header or one starting with ".". The previous
// len(...) < 1 check could never fire because strings.Split always returns
// at least one element.
func (hs *httpServer) extractSlug(reqhf RequestHeaderManager) (string, error) {
	slug, _, _ := strings.Cut(reqhf.Get("Host"), ".")
	if slug == "" {
		return "", errors.New("invalid host")
	}
	return slug, nil
}
// shouldRedirectToTLS reports whether a request must be bounced to HTTPS.
// Requests that already arrived over TLS never are; plain-text ones are
// redirected only when the server was built with redirectTLS enabled.
func (hs *httpServer) shouldRedirectToTLS(isTLS bool) bool {
	if isTLS {
		return false
	}
	return hs.redirectTLS
}
// handlePingRequest answers the built-in health check.
//
// For the reserved slug "ping" it writes an empty 200 OK with permissive CORS
// headers to conn (logging, not returning, any write error) and reports true
// so the caller stops processing. Any other slug is left untouched and false
// is returned.
func (hs *httpServer) handlePingRequest(slug string, conn net.Conn) bool {
	if slug == "ping" {
		lines := []string{
			"HTTP/1.1 200 OK",
			"Content-Length: 0",
			"Connection: close",
			"Access-Control-Allow-Origin: *",
			"Access-Control-Allow-Methods: GET, HEAD, OPTIONS",
			"Access-Control-Allow-Headers: *",
		}
		// Each header line ends in CRLF and an empty line ends the header block.
		reply := strings.Join(lines, "\r\n") + "\r\n\r\n"
		if _, writeErr := conn.Write([]byte(reply)); writeErr != nil {
			log.Println("Failed to write 200 OK:", writeErr)
		}
		return true
	}
	return false
}
// getSession looks up the HTTP-type session registered under slug.
// When the registry reports an error, a nil session is returned alongside it.
func (hs *httpServer) getSession(slug string) (session.Session, error) {
	key := types.SessionKey{Id: slug, Type: types.HTTP}
	found, lookupErr := hs.sessionRegistry.Get(key)
	if lookupErr == nil {
		return found, nil
	}
	return nil, lookupErr
}
// forwardRequest relays the client connection behind hw through sshSession.
//
// It does the following in order:
//  1. Opens a forwarded-tcpip channel. On failure it logs and replies via the
//     forwarder's WriteBadGatewayResponse.
//  2. Installs the per-connection middlewares.
//  3. Writes the already-parsed initial request to the channel.
//  4. Hands hw and the channel to the forwarder's HandleConnection.
func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) {
	channel, err := hs.openForwardedChannel(hw, sshSession)
	if err != nil {
		log.Printf("Failed to establish channel: %v", err)
		sshSession.Forwarder().WriteBadGatewayResponse(hw)
		return
	}
	hs.setupMiddlewares(hw)
	if err := hs.sendInitialRequest(hw, initialRequest, channel); err != nil {
		log.Printf("Failed to forward initial request: %v", err)
		// HandleConnection is never reached on this path, so the channel
		// would otherwise be left open; close it to avoid leaking it.
		if closeErr := channel.Close(); closeErr != nil {
			log.Printf("Failed to close channel: %v", closeErr)
		}
		return
	}
	sshSession.Forwarder().HandleConnection(hw, channel, hw.RemoteAddr())
}
// openForwardedChannel asks sshSession's SSH connection to open a
// "forwarded-tcpip" channel for the client behind hw, waiting at most
// 5 seconds for the answer.
//
// On success, the channel's out-of-band requests are discarded in the
// background and the channel is returned. Closing it is then up to the
// caller, presumably via the forwarder; confirm. On an open error or timeout,
// no channel is returned. If the open completes only after the timeout fired,
// the late channel is cleaned up here instead of leaking.
func (hs *httpServer) openForwardedChannel(hw HTTPWriter, sshSession session.Session) (ssh.Channel, error) {
	payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
	type channelResult struct {
		channel ssh.Channel
		reqs    <-chan *ssh.Request
		err     error
	}
	// resultChan is deliberately unbuffered: a send succeeds only while the
	// select below is still waiting for it. The previous 1-slot buffer let the
	// send always succeed, so the cleanup path was unreachable and channels
	// opened after the timeout leaked.
	resultChan := make(chan channelResult)
	// abandoned is closed when the caller stops waiting (timeout).
	abandoned := make(chan struct{})
	go func() {
		channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
		select {
		case resultChan <- channelResult{channel, reqs, err}:
		case <-abandoned:
			hs.cleanupUnusedChannel(channel, reqs)
		}
	}()
	timer := time.NewTimer(5 * time.Second)
	defer timer.Stop()
	select {
	case result := <-resultChan:
		if result.err != nil {
			return nil, result.err
		}
		go ssh.DiscardRequests(result.reqs)
		return result.channel, nil
	case <-timer.C:
		close(abandoned)
		return nil, errors.New("timeout opening forwarded-tcpip channel")
	}
}
// cleanupUnusedChannel disposes of a channel that was opened but will not be
// used. A nil channel is ignored. Otherwise, the channel is closed (a close
// failure is logged) and its request stream is drained in a new goroutine.
func (hs *httpServer) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
	if channel == nil {
		return
	}
	if closeErr := channel.Close(); closeErr != nil {
		log.Printf("Failed to close unused channel: %v", closeErr)
	}
	go ssh.DiscardRequests(reqs)
}
// setupMiddlewares registers the per-connection middlewares on hw.
// The tunnel-fingerprint middleware goes on responses, and the forwarded-for
// middleware (seeded with the client's remote address) goes on requests.
func (hs *httpServer) setupMiddlewares(hw HTTPWriter) {
	respFingerprint := NewTunnelFingerprint()
	reqForwardedFor := NewForwardedFor(hw.RemoteAddr())

	hw.UseResponseMiddleware(respFingerprint)
	hw.UseRequestMiddleware(reqForwardedFor)
}
// sendInitialRequest pushes the request header that was already read from the
// client down the SSH channel.
//
// It first records the header on hw, then runs the request middlewares on it,
// and finally writes the finalized bytes to channel. Errors from the
// middleware or write step are returned wrapped with context.
func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeaderManager, channel ssh.Channel) error {
	hw.SetRequestHeader(initialRequest)

	err := hw.ApplyRequestMiddlewares(initialRequest)
	if err != nil {
		return fmt.Errorf("error applying request middlewares: %w", err)
	}

	// Finalize only after the middlewares have had a chance to edit headers.
	raw := initialRequest.Finalize()
	if _, err = channel.Write(raw); err != nil {
		return fmt.Errorf("error writing to channel: %w", err)
	}
	return nil
}