diff --git a/server/http.go b/server/http.go index 9c00b28..f685d39 100644 --- a/server/http.go +++ b/server/http.go @@ -2,13 +2,12 @@ package server import ( "bufio" - "bytes" + "crypto/tls" "errors" "fmt" - "io" "log" "net" - "regexp" + "net/http" "strings" "time" "tunnel_pls/internal/config" @@ -18,214 +17,20 @@ import ( "golang.org/x/crypto/ssh" ) -type HTTPWriter interface { - io.Reader - io.Writer - GetRemoteAddr() net.Addr - GetWriter() io.Writer - AddResponseMiddleware(mw ResponseMiddleware) - AddRequestStartMiddleware(mw RequestMiddleware) - SetRequestHeader(header RequestHeaderManager) - GetRequestStartMiddleware() []RequestMiddleware -} - -type customWriter struct { - remoteAddr net.Addr - writer io.Writer - reader io.Reader - headerBuf []byte - buf []byte - respHeader ResponseHeaderManager - reqHeader RequestHeaderManager - respMW []ResponseMiddleware - reqStartMW []RequestMiddleware - reqEndMW []RequestMiddleware -} - -func (cw *customWriter) GetRemoteAddr() net.Addr { - return cw.remoteAddr -} - -func (cw *customWriter) GetWriter() io.Writer { - return cw.writer -} - -func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) { - cw.respMW = append(cw.respMW, mw) -} - -func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) { - cw.reqStartMW = append(cw.reqStartMW, mw) -} - -func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) { - cw.reqHeader = header -} - -func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware { - return cw.reqStartMW -} - -func (cw *customWriter) Read(p []byte) (int, error) { - tmp := make([]byte, len(p)) - read, err := cw.reader.Read(tmp) - if read == 0 && err != nil { - return 0, err - } - - tmp = tmp[:read] - - idx := bytes.Index(tmp, DELIMITER) - if idx == -1 { - copy(p, tmp) - if err != nil { - return read, err - } - return read, nil - } - - header := tmp[:idx+len(DELIMITER)] - body := tmp[idx+len(DELIMITER):] - - if !isHTTPHeader(header) { - copy(p, tmp) - return read, nil - } - - for _, m := range cw.reqEndMW { - err = m.HandleRequest(cw.reqHeader) - if err != nil { - log.Printf("Error when applying request middleware: %v", err) - return 0, err - } - } - - reqhf, err := NewRequestHeaderFactory(header) - if err != nil { - return 0, err - } - - for _, m := range cw.reqStartMW { - if mwErr := m.HandleRequest(reqhf); mwErr != nil { - log.Printf("Error when applying request middleware: %v", mwErr) - return 0, mwErr - } - } - - cw.reqHeader = reqhf - finalHeader := reqhf.Finalize() - - combined := append(finalHeader, body...) - - n := copy(p, combined) - - return n, nil -} - -func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter { - return &customWriter{ - remoteAddr: remoteAddr, - writer: writer, - reader: reader, - buf: make([]byte, 0, 4096), - } -} - -var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} -var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) -var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) - -func isHTTPHeader(buf []byte) bool { - lines := bytes.Split(buf, []byte("\r\n")) - - startLine := string(lines[0]) - if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { - return false - } - - for _, line := range lines[1:] { - if len(line) == 0 { - break - } - colonIdx := bytes.IndexByte(line, ':') - if colonIdx <= 0 { - return false - } - } - return true -} - -func (cw *customWriter) Write(p []byte) (int, error) { - if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" { - cw.respHeader = nil - } - - if cw.respHeader != nil { - n, err := cw.writer.Write(p) - if err != nil { - return n, err - } - return n, nil - } - - cw.buf = append(cw.buf, p...) - - idx := bytes.Index(cw.buf, DELIMITER) - if idx == -1 { - return len(p), nil - } - - header := cw.buf[:idx+len(DELIMITER)] - body := cw.buf[idx+len(DELIMITER):] - - if !isHTTPHeader(header) { - _, err := cw.writer.Write(cw.buf) - cw.buf = nil - if err != nil { - return 0, err - } - return len(p), nil - } - - resphf := NewResponseHeaderFactory(header) - for _, m := range cw.respMW { - err := m.HandleResponse(resphf, body) - if err != nil { - log.Printf("Cannot apply middleware: %s\n", err) - return 0, err - } - } - header = resphf.Finalize() - cw.respHeader = resphf - - _, err := cw.writer.Write(header) - if err != nil { - return 0, err - } - if len(body) > 0 { - _, err = cw.writer.Write(body) - if err != nil { - return 0, err - } - } - cw.buf = nil - return len(p), nil -} - -var redirectTLS = false - type HTTPServer interface { ListenAndServe() error ListenAndServeTLS() error - handler(conn net.Conn) - handlerTLS(conn net.Conn) } type httpServer struct { sessionRegistry session.Registry + redirectTLS bool } -func NewHTTPServer(sessionRegistry session.Registry) HTTPServer { - return &httpServer{sessionRegistry: sessionRegistry} +func NewHTTPServer(sessionRegistry session.Registry, redirectTLS bool) HTTPServer { + return &httpServer{ + sessionRegistry: sessionRegistry, + redirectTLS: redirectTLS, + } } func (hs *httpServer) ListenAndServe() error { @@ -234,9 +39,6 @@ func (hs *httpServer) ListenAndServe() error { if err != nil { return errors.New("Error listening: " + err.Error()) } - if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" { - redirectTLS = true - } go func() { for { var conn net.Conn @@ -249,21 +51,65 @@ func (hs *httpServer) ListenAndServe() error { continue } - go hs.handler(conn) + go hs.handler(conn, false) } }() return nil } -func (hs *httpServer) handler(conn net.Conn) { - defer func() { - err := conn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - log.Printf("Error closing connection: %v", err) - return +func (hs *httpServer) ListenAndServeTLS() error { + domain := config.Getenv("DOMAIN", "localhost") + httpsPort := config.Getenv("HTTPS_PORT", "8443") + + tlsConfig, err := NewTLSConfig(domain) + if err != nil { + return fmt.Errorf("failed to initialize TLS config: %w", err) + } + + ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig) + if err != nil { + return err + } + + go func() { + for { + var conn net.Conn + conn, err = ln.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + log.Println("https server closed") + } + log.Printf("Error accepting connection: %v", err) + continue + } + + go hs.handler(conn, true) } - return }() + return nil +} + +func (hs *httpServer) redirect(conn net.Conn, status int, location string) error { + _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + + fmt.Sprintf("Location: %s", location) + + "Content-Length: 0\r\n" + + "Connection: close\r\n" + + "\r\n")) + if err != nil { + return err + } + return nil +} + +func (hs *httpServer) badRequest(conn net.Conn) error { + if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { + return err + } + return nil +} + +func (hs *httpServer) handler(conn net.Conn, isTLS bool) { + defer hs.closeConnection(conn) dstReader := bufio.NewReader(conn) reqhf, err := NewRequestHeaderFactory(dstReader) @@ -272,77 +118,108 @@ func (hs *httpServer) handler(conn net.Conn) { return } + slug, err := hs.extractSlug(reqhf) + if err != nil { + _ = hs.badRequest(conn) + return + } + + if hs.shouldRedirectToTLS(isTLS) { + _ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))) + return + } + + if hs.handlePingRequest(slug, conn) { + return + } + + sshSession, err := hs.getSession(slug) + if err != nil { + _ = hs.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) + return + } + + hw := NewHTTPWriter(conn, dstReader, conn.RemoteAddr()) + hs.forwardRequest(hw, reqhf, sshSession) +} + +func (hs *httpServer) closeConnection(conn net.Conn) { + err := conn.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + log.Printf("Error closing connection: %v", err) + } +} + +func (hs *httpServer) extractSlug(reqhf RequestHeaderManager) (string, error) { host := strings.Split(reqhf.Get("Host"), ".") if len(host) < 1 { - _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - if err != nil { - log.Println("Failed to write 400 Bad Request:", err) - return - } - return + return "", errors.New("invalid host") + } + return host[0], nil +} + +func (hs *httpServer) shouldRedirectToTLS(isTLS bool) bool { + return !isTLS && hs.redirectTLS +} + +func (hs *httpServer) handlePingRequest(slug string, conn net.Conn) bool { + if slug != "ping" { + return false } - slug := host[0] - - if redirectTLS { - _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) + + _, err := conn.Write([]byte( + "HTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Connection: close\r\n" + - "\r\n")) - if err != nil { - log.Println("Failed to write 301 Moved Permanently:", err) - return - } - return - } - - if slug == "ping" { - _, err = conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "Access-Control-Allow-Origin: *\r\n" + - "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + - "Access-Control-Allow-Headers: *\r\n" + - "\r\n", - )) - if err != nil { - log.Println("Failed to write 200 OK:", err) - return - } - return + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + + "Access-Control-Allow-Headers: *\r\n" + + "\r\n", + )) + if err != nil { + log.Println("Failed to write 200 OK:", err) } + return true +} +func (hs *httpServer) getSession(slug string) (session.Session, error) { sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ Id: slug, Type: types.HTTP, }) if err != nil { - _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n")) - if err != nil { - log.Println("Failed to write 301 Moved Permanently:", err) - return - } - return + return nil, err } - cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - forwardRequest(cw, reqhf, sshSession) - return + return sshSession, nil } -func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) { - payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr()) +func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) { + channel, err := hs.openForwardedChannel(hw, sshSession) + if err != nil { + log.Printf("Failed to establish channel: %v", err) + sshSession.Forwarder().WriteBadGatewayResponse(hw) + return + } + + hs.setupMiddlewares(hw) + + if err := hs.sendInitialRequest(hw, initialRequest, channel); err != nil { + log.Printf("Failed to forward initial request: %v", err) + return + } + + sshSession.Forwarder().HandleConnection(hw, channel, hw.RemoteAddr()) +} + +func (hs *httpServer) openForwardedChannel(hw HTTPWriter, sshSession session.Session) (ssh.Channel, error) { + payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) type channelResult struct { channel ssh.Channel reqs <-chan *ssh.Request err error } + resultChan := make(chan channelResult, 1) go func() { @@ -350,57 +227,49 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi select { case resultChan <- channelResult{channel, reqs, err}: default: - if channel != nil { - err := channel.Close() - if err != nil { - log.Printf("Failed to close unused channel: %v", err) - return - } - go ssh.DiscardRequests(reqs) - } + hs.cleanupUnusedChannel(channel, reqs) } }() - var channel ssh.Channel - var reqs <-chan *ssh.Request - select { case result := <-resultChan: if result.err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) - sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) - return + return nil, result.err } - channel = result.channel - reqs = result.reqs + go ssh.DiscardRequests(result.reqs) + return result.channel, nil case <-time.After(5 * time.Second): - log.Printf("Timeout opening forwarded-tcpip channel") - sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) - return + return nil, errors.New("timeout opening forwarded-tcpip channel") } - - go ssh.DiscardRequests(reqs) - - fingerprintMiddleware := NewTunnelFingerprint() - forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr()) - - cw.AddResponseMiddleware(fingerprintMiddleware) - cw.AddRequestStartMiddleware(forwardedForMiddleware) - cw.SetRequestHeader(initialRequest) - - for _, m := range cw.GetRequestStartMiddleware() { - if err := m.HandleRequest(initialRequest); err != nil { - log.Printf("Error handling request: %v", err) - return - } - } - - _, err := channel.Write(initialRequest.Finalize()) - if err != nil { - log.Printf("Failed to forward request: %v", err) - return - } - - sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr()) - return +} + +func (hs *httpServer) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) { + if channel != nil { + if err := channel.Close(); err != nil { + log.Printf("Failed to close unused channel: %v", err) + } + go ssh.DiscardRequests(reqs) + } +} + +func (hs *httpServer) setupMiddlewares(hw HTTPWriter) { + fingerprintMiddleware := NewTunnelFingerprint() + forwardedForMiddleware := NewForwardedFor(hw.RemoteAddr()) + + hw.UseResponseMiddleware(fingerprintMiddleware) + hw.UseRequestMiddleware(forwardedForMiddleware) +} + +func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeaderManager, channel ssh.Channel) error { + hw.SetRequestHeader(initialRequest) + + if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { + return fmt.Errorf("error applying request middlewares: %w", err) + } + + if _, err := channel.Write(initialRequest.Finalize()); err != nil { + return fmt.Errorf("error writing to channel: %w", err) + } + + return nil } diff --git a/server/https.go b/server/https.go deleted file mode 100644 index 2758172..0000000 --- a/server/https.go +++ /dev/null @@ -1,112 +0,0 @@ -package server - -import ( - "bufio" - "crypto/tls" - "errors" - "fmt" - "log" - "net" - "strings" - "tunnel_pls/internal/config" - "tunnel_pls/types" -) - -func (hs *httpServer) ListenAndServeTLS() error { - domain := config.Getenv("DOMAIN", "localhost") - httpsPort := config.Getenv("HTTPS_PORT", "8443") - - tlsConfig, err := NewTLSConfig(domain) - if err != nil { - return fmt.Errorf("failed to initialize TLS config: %w", err) - } - - ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig) - if err != nil { - return err - } - - go func() { - for { - var conn net.Conn - conn, err = ln.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - log.Println("https server closed") - } - log.Printf("Error accepting connection: %v", err) - continue - } - - go hs.handlerTLS(conn) - } - }() - return nil -} - -func (hs *httpServer) handlerTLS(conn net.Conn) { - defer func() { - err := conn.Close() - if err != nil { - log.Printf("Error closing connection: %v", err) - return - } - return - }() - - dstReader := bufio.NewReader(conn) - reqhf, err := NewRequestHeaderFactory(dstReader) - if err != nil { - log.Printf("Error creating request header: %v", err) - return - } - - host := strings.Split(reqhf.Get("Host"), ".") - if len(host) < 1 { - _, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) - if err != nil { - log.Println("Failed to write 400 Bad Request:", err) - return - } - return - } - - slug := host[0] - - if slug == "ping" { - _, err = conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "Access-Control-Allow-Origin: *\r\n" + - "Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" + - "Access-Control-Allow-Headers: *\r\n" + - "\r\n", - )) - if err != nil { - log.Println("Failed to write 200 OK:", err) - return - } - return - } - - sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ - Id: slug, - Type: types.HTTP, - }) - if err != nil { - _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + - "Content-Length: 0\r\n" + - "Connection: close\r\n" + - "\r\n")) - if err != nil { - log.Println("Failed to write 301 Moved Permanently:", err) - return - } - return - } - cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - forwardRequest(cw, reqhf, sshSession) - return -} diff --git a/server/httpwritter.go b/server/httpwritter.go new file mode 100644 index 0000000..bde7452 --- /dev/null +++ b/server/httpwritter.go @@ -0,0 +1,250 @@ +package server + +import ( + "bytes" + "io" + "log" + "net" + "regexp" +) + +type HTTPWriter interface { + io.ReadWriteCloser + CloseWrite() error + RemoteAddr() net.Addr + UseResponseMiddleware(mw ResponseMiddleware) + UseRequestMiddleware(mw RequestMiddleware) + SetRequestHeader(header RequestHeaderManager) + RequestMiddlewares() []RequestMiddleware + ResponseMiddlewares() []ResponseMiddleware + ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error + ApplyRequestMiddlewares(reqhf RequestHeaderManager) error +} + +type httpWriter struct { + remoteAddr net.Addr + writer io.Writer + reader io.Reader + headerBuf []byte + buf []byte + respHeader ResponseHeaderManager + reqHeader RequestHeaderManager + respMW []ResponseMiddleware + reqMW []RequestMiddleware +} + +var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} +var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) +var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) + +func (hw *httpWriter) RemoteAddr() net.Addr { + return hw.remoteAddr +} + +func (hw *httpWriter) UseResponseMiddleware(mw ResponseMiddleware) { + hw.respMW = append(hw.respMW, mw) +} + +func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) { + hw.reqMW = append(hw.reqMW, mw) +} + +func (hw *httpWriter) SetRequestHeader(header RequestHeaderManager) { + hw.reqHeader = header +} + +func (hw *httpWriter) RequestMiddlewares() []RequestMiddleware { + return hw.reqMW +} + +func (hw *httpWriter) ResponseMiddlewares() []ResponseMiddleware { + return hw.respMW +} +func (hw *httpWriter) Close() error { + return hw.writer.(io.Closer).Close() +} + +func (hw *httpWriter) CloseWrite() error { + if closer, ok := hw.writer.(interface{ CloseWrite() error }); ok { + return closer.CloseWrite() + } + return hw.Close() +} + +func (hw *httpWriter) Read(p []byte) (int, error) { + tmp := make([]byte, len(p)) + read, err := hw.reader.Read(tmp) + if read == 0 && err != nil { + return 0, err + } + + tmp = tmp[:read] + + headerEndIdx := bytes.Index(tmp, DELIMITER) + if headerEndIdx == -1 { + return hw.handleNoDelimiter(p, tmp, err) + } + + header, body := hw.splitHeaderAndBody(tmp, headerEndIdx) + + if !isHTTPHeader(header) { + copy(p, tmp) + return read, nil + } + + return hw.processHTTPRequest(p, header, body) +} + +func (hw *httpWriter) handleNoDelimiter(p, tmp []byte, err error) (int, error) { + copy(p, tmp) + return len(tmp), err +} + +func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) { + header := data[:delimiterIdx+len(DELIMITER)] + body := data[delimiterIdx+len(DELIMITER):] + return header, body +} + +func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { + reqhf, err := NewRequestHeaderFactory(header) + if err != nil { + return 0, err + } + + if err = hw.ApplyRequestMiddlewares(reqhf); err != nil { + return 0, err + } + + hw.reqHeader = reqhf + combined := append(reqhf.Finalize(), body...) + return copy(p, combined), nil +} + +func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeaderManager) error { + for _, m := range hw.RequestMiddlewares() { + if err := m.HandleRequest(reqhf); err != nil { + log.Printf("Error when applying request middleware: %v", err) + return err + } + } + return nil +} + +func (hw *httpWriter) Write(p []byte) (int, error) { + if hw.shouldBypassBuffering(p) { + hw.respHeader = nil + } + + if hw.respHeader != nil { + return hw.writer.Write(p) + } + + hw.buf = append(hw.buf, p...) + + headerEndIdx := bytes.Index(hw.buf, DELIMITER) + if headerEndIdx == -1 { + return len(p), nil + } + + return hw.processBufferedResponse(p, headerEndIdx) +} + +func (hw *httpWriter) shouldBypassBuffering(p []byte) bool { + return hw.respHeader != nil && len(hw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" +} + +func (hw *httpWriter) processBufferedResponse(p []byte, delimiterIdx int) (int, error) { + header, body := hw.splitHeaderAndBody(hw.buf, delimiterIdx) + + if !isHTTPHeader(header) { + return hw.writeRawBuffer() + } + + if err := hw.processHTTPResponse(header, body); err != nil { + return 0, err + } + + hw.buf = nil + return len(p), nil +} + +func (hw *httpWriter) writeRawBuffer() (int, error) { + _, err := hw.writer.Write(hw.buf) + length := len(hw.buf) + hw.buf = nil + if err != nil { + return 0, err + } + return length, nil +} + +func (hw *httpWriter) processHTTPResponse(header, body []byte) error { + resphf := NewResponseHeaderFactory(header) + + if err := hw.ApplyResponseMiddlewares(resphf, body); err != nil { + return err + } + + hw.respHeader = resphf + finalHeader := resphf.Finalize() + + if err := hw.writeHeaderAndBody(finalHeader, body); err != nil { + return err + } + + return nil +} + +func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error { + for _, m := range hw.ResponseMiddlewares() { + if err := m.HandleResponse(resphf, body); err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return err + } + } + return nil +} + +func (hw *httpWriter) writeHeaderAndBody(header, body []byte) error { + if _, err := hw.writer.Write(header); err != nil { + return err + } + + if len(body) > 0 { + if _, err := hw.writer.Write(body); err != nil { + return err + } + } + + return nil +} + +func NewHTTPWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter { + return &httpWriter{ + remoteAddr: remoteAddr, + writer: writer, + reader: reader, + buf: make([]byte, 0, 4096), + } +} + +func isHTTPHeader(buf []byte) bool { + lines := bytes.Split(buf, []byte("\r\n")) + + startLine := string(lines[0]) + if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { + return false + } + + for _, line := range lines[1:] { + if len(line) == 0 { + break + } + colonIdx := bytes.IndexByte(line, ':') + if colonIdx <= 0 { + return false + } + } + return true +} diff --git a/server/server.go b/server/server.go index 792f47e..868b9e6 100644 --- a/server/server.go +++ b/server/server.go @@ -33,8 +33,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie log.Fatalf("failed to listen on port 2200: %v", err) return nil, err } + redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" - HttpServer := NewHTTPServer(sessionRegistry) + HttpServer := NewHTTPServer(sessionRegistry, redirectTLS) err = HttpServer.ListenAndServe() if err != nil { log.Fatalf("failed to start http server: %v", err) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index fcbc12f..fa1dff4 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -124,18 +124,6 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA if err != nil { log.Printf("Failed to discard connection: %v", err) } - - err = src.Close() - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing source channel: %v", err) - } - - if closer, ok := dst.(io.Closer); ok { - err = closer.Close() - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing destination connection: %v", err) - } - } }() log.Printf("Handling new forwarded connection from %s", remoteAddr)