- Rename customWriter struct to httpWriter for clarity - Add closeWriter field to properly close write side of connections - Update all cw variable references to hw - Merge handlerTLS into handler function to reduce code duplication - Extract handler into smaller, focused methods - Split Read/Write/forwardRequest into composable functions Fixes resource leak where connections weren't properly closed on the write side, matching the forwarder's CloseWrite() pattern.
276 lines
6.5 KiB
Go
276 lines
6.5 KiB
Go
package server
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
"tunnel_pls/internal/config"
|
|
"tunnel_pls/session"
|
|
"tunnel_pls/types"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// HTTPServer is the public surface of the tunnel's HTTP front end. It exposes
// one entry point for the plain-text listener and one for the TLS listener.
type HTTPServer interface {
	// ListenAndServe starts serving plain HTTP and reports a startup error, if any.
	ListenAndServe() error
	// ListenAndServeTLS starts serving HTTPS and reports a startup error, if any.
	ListenAndServeTLS() error
}
|
|
type httpServer struct {
|
|
sessionRegistry session.Registry
|
|
redirectTLS bool
|
|
}
|
|
|
|
func NewHTTPServer(sessionRegistry session.Registry, redirectTLS bool) HTTPServer {
|
|
return &httpServer{
|
|
sessionRegistry: sessionRegistry,
|
|
redirectTLS: redirectTLS,
|
|
}
|
|
}
|
|
|
|
func (hs *httpServer) ListenAndServe() error {
|
|
httpPort := config.Getenv("HTTP_PORT", "8080")
|
|
listener, err := net.Listen("tcp", ":"+httpPort)
|
|
if err != nil {
|
|
return errors.New("Error listening: " + err.Error())
|
|
}
|
|
go func() {
|
|
for {
|
|
var conn net.Conn
|
|
conn, err = listener.Accept()
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
return
|
|
}
|
|
log.Printf("Error accepting connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
go hs.handler(conn, false)
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (hs *httpServer) ListenAndServeTLS() error {
|
|
domain := config.Getenv("DOMAIN", "localhost")
|
|
httpsPort := config.Getenv("HTTPS_PORT", "8443")
|
|
|
|
tlsConfig, err := NewTLSConfig(domain)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize TLS config: %w", err)
|
|
}
|
|
|
|
ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
for {
|
|
var conn net.Conn
|
|
conn, err = ln.Accept()
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
log.Println("https server closed")
|
|
}
|
|
log.Printf("Error accepting connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
go hs.handler(conn, true)
|
|
}
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (hs *httpServer) redirect(conn net.Conn, status int, location string) error {
|
|
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
|
|
fmt.Sprintf("Location: %s", location) +
|
|
"Content-Length: 0\r\n" +
|
|
"Connection: close\r\n" +
|
|
"\r\n"))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (hs *httpServer) badRequest(conn net.Conn) error {
|
|
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (hs *httpServer) handler(conn net.Conn, isTLS bool) {
|
|
defer hs.closeConnection(conn)
|
|
|
|
dstReader := bufio.NewReader(conn)
|
|
reqhf, err := NewRequestHeaderFactory(dstReader)
|
|
if err != nil {
|
|
log.Printf("Error creating request header: %v", err)
|
|
return
|
|
}
|
|
|
|
slug, err := hs.extractSlug(reqhf)
|
|
if err != nil {
|
|
_ = hs.badRequest(conn)
|
|
return
|
|
}
|
|
|
|
if hs.shouldRedirectToTLS(isTLS) {
|
|
_ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")))
|
|
return
|
|
}
|
|
|
|
if hs.handlePingRequest(slug, conn) {
|
|
return
|
|
}
|
|
|
|
sshSession, err := hs.getSession(slug)
|
|
if err != nil {
|
|
_ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
|
|
return
|
|
}
|
|
|
|
hw := NewHTTPWriter(conn, dstReader, conn.RemoteAddr())
|
|
hs.forwardRequest(hw, reqhf, sshSession)
|
|
}
|
|
|
|
func (hs *httpServer) closeConnection(conn net.Conn) {
|
|
err := conn.Close()
|
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
log.Printf("Error closing connection: %v", err)
|
|
}
|
|
}
|
|
|
|
func (hs *httpServer) extractSlug(reqhf RequestHeaderManager) (string, error) {
|
|
host := strings.Split(reqhf.Get("Host"), ".")
|
|
if len(host) < 1 {
|
|
return "", errors.New("invalid host")
|
|
}
|
|
return host[0], nil
|
|
}
|
|
|
|
func (hs *httpServer) shouldRedirectToTLS(isTLS bool) bool {
|
|
return !isTLS && hs.redirectTLS
|
|
}
|
|
|
|
func (hs *httpServer) handlePingRequest(slug string, conn net.Conn) bool {
|
|
if slug != "ping" {
|
|
return false
|
|
}
|
|
|
|
_, err := conn.Write([]byte(
|
|
"HTTP/1.1 200 OK\r\n" +
|
|
"Content-Length: 0\r\n" +
|
|
"Connection: close\r\n" +
|
|
"Access-Control-Allow-Origin: *\r\n" +
|
|
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
|
|
"Access-Control-Allow-Headers: *\r\n" +
|
|
"\r\n",
|
|
))
|
|
if err != nil {
|
|
log.Println("Failed to write 200 OK:", err)
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (hs *httpServer) getSession(slug string) (session.Session, error) {
|
|
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
|
|
Id: slug,
|
|
Type: types.HTTP,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return sshSession, nil
|
|
}
|
|
|
|
func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) {
|
|
channel, err := hs.openForwardedChannel(hw, sshSession)
|
|
if err != nil {
|
|
log.Printf("Failed to establish channel: %v", err)
|
|
sshSession.Forwarder().WriteBadGatewayResponse(hw)
|
|
return
|
|
}
|
|
|
|
hs.setupMiddlewares(hw)
|
|
|
|
if err := hs.sendInitialRequest(hw, initialRequest, channel); err != nil {
|
|
log.Printf("Failed to forward initial request: %v", err)
|
|
return
|
|
}
|
|
|
|
sshSession.Forwarder().HandleConnection(hw, channel, hw.RemoteAddr())
|
|
}
|
|
|
|
func (hs *httpServer) openForwardedChannel(hw HTTPWriter, sshSession session.Session) (ssh.Channel, error) {
|
|
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
|
|
|
|
type channelResult struct {
|
|
channel ssh.Channel
|
|
reqs <-chan *ssh.Request
|
|
err error
|
|
}
|
|
|
|
resultChan := make(chan channelResult, 1)
|
|
|
|
go func() {
|
|
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
|
|
select {
|
|
case resultChan <- channelResult{channel, reqs, err}:
|
|
default:
|
|
hs.cleanupUnusedChannel(channel, reqs)
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case result := <-resultChan:
|
|
if result.err != nil {
|
|
return nil, result.err
|
|
}
|
|
go ssh.DiscardRequests(result.reqs)
|
|
return result.channel, nil
|
|
case <-time.After(5 * time.Second):
|
|
return nil, errors.New("timeout opening forwarded-tcpip channel")
|
|
}
|
|
}
|
|
|
|
func (hs *httpServer) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
|
|
if channel != nil {
|
|
if err := channel.Close(); err != nil {
|
|
log.Printf("Failed to close unused channel: %v", err)
|
|
}
|
|
go ssh.DiscardRequests(reqs)
|
|
}
|
|
}
|
|
|
|
func (hs *httpServer) setupMiddlewares(hw HTTPWriter) {
|
|
fingerprintMiddleware := NewTunnelFingerprint()
|
|
forwardedForMiddleware := NewForwardedFor(hw.RemoteAddr())
|
|
|
|
hw.UseResponseMiddleware(fingerprintMiddleware)
|
|
hw.UseRequestMiddleware(forwardedForMiddleware)
|
|
}
|
|
|
|
func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeaderManager, channel ssh.Channel) error {
|
|
hw.SetRequestHeader(initialRequest)
|
|
|
|
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
|
|
return fmt.Errorf("error applying request middlewares: %w", err)
|
|
}
|
|
|
|
if _, err := channel.Write(initialRequest.Finalize()); err != nil {
|
|
return fmt.Errorf("error writing to channel: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|