diff --git a/certs/cert.pem b/certs/cert.pem new file mode 100644 index 0000000..e69de29 diff --git a/http/http.go b/http/http.go index 4ba6c0b..ef5efcc 100644 --- a/http/http.go +++ b/http/http.go @@ -127,15 +127,19 @@ func handleRequest(conn net.Conn) { return } - if r.Host == utils.Getenv("domain") { - writer := &tcpResponseWriter{ - conn: conn, - header: make(http.Header), - status: http.StatusOK, - } - fmt.Println(r.Pattern) - router.ServeHTTP(writer, r) + writer := &tcpResponseWriter{ + conn: conn, + header: make(http.Header), + status: http.StatusOK, + } + if r.Host == utils.Getenv("domain") { + router.ServeHTTP(writer, r) + return + } + + if utils.Getenv("tls_enabled") == "false" { + http.Redirect(writer, r, fmt.Sprintf("https://%s%s", r.Host, r.URL.RequestURI()), http.StatusFound) return } diff --git a/http/https.go b/http/https.go new file mode 100644 index 0000000..3011ecf --- /dev/null +++ b/http/https.go @@ -0,0 +1,95 @@ +package httpServer + +import ( + "bufio" + "crypto/tls" + "fmt" + "io" + "log" + "net" + "net/http" + "strings" + "tunnel_pls/session" + "tunnel_pls/utils" +) + +func ListenTLS(config *tls.Config) { + server, err := tls.Listen("tcp", ":443", config) + if err != nil { + return + } + + if err != nil { + log.Fatal(err) + return + } + + defer server.Close() + log.Println("Listening on :443") + for { + conn, err := server.Accept() + if err != nil { + log.Fatal(err) + return + } + + go handleRequestTLS(conn) + } +} + +func handleRequestTLS(conn net.Conn) { + defer conn.Close() + var rawRequest string + + reader := bufio.NewReader(conn) + r, err := http.ReadRequest(reader) + if err != nil { + fmt.Println("Error reading request:", err) + return + } + + writer := &tcpResponseWriter{ + conn: conn, + header: make(http.Header), + status: http.StatusOK, + } + + if r.Host == utils.Getenv("domain") { + router.ServeHTTP(writer, r) + return + } + + slug := strings.Split(r.Host, ".")[0] + if slug == "" { + fmt.Println("Error parsing slug: ", r.Host) + return + } + + sshSession, ok := session.Clients[slug] + if !ok { + fmt.Println("Error finding ssh session: ", slug) + return + } + + rawRequest += fmt.Sprintf("%s %s %s\r\n", r.Method, r.URL.RequestURI(), r.Proto) + rawRequest += fmt.Sprintf("Host: %s\r\n", r.Host) + + for k, v := range r.Header { + rawRequest += fmt.Sprintf("%s: %s\r\n", k, v[0]) + } + rawRequest += "\r\n" + + if r.Body != nil { + body, err := io.ReadAll(r.Body) + if err != nil { + log.Println("Error reading request body:", err) + } else { + rawRequest += string(body) + } + } + + payload := []byte(rawRequest) + + host, originPort := session.ParseAddr(conn.RemoteAddr().String()) + sshSession.GetForwardedConnection(conn, host, sshSession.Connection, payload, originPort, 80, r.RequestURI, r.Method, r.Proto) +} diff --git a/main b/main new file mode 100644 index 0000000..7853d5b Binary files /dev/null and b/main differ diff --git a/server/server.go b/server/server.go index 1c4e546..57c8278 100644 --- a/server/server.go +++ b/server/server.go @@ -1,12 +1,14 @@ package server import ( + "crypto/tls" "fmt" "golang.org/x/crypto/ssh" "log" "net" "net/http" httpServer "tunnel_pls/http" + "tunnel_pls/utils" ) type Server struct { @@ -22,6 +24,15 @@ func NewServer(config ssh.ServerConfig) *Server { return nil } go httpServer.Listen() + if utils.Getenv("tls_enabled") == "true" { + cert, err := tls.LoadX509KeyPair(utils.Getenv("cert_loc"), utils.Getenv("key_loc")) + if err != nil { + log.Fatal("Failed to load key pair:", err) + } + + tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + go httpServer.ListenTLS(tlsConfig) + } return &Server{ Conn: &listener, Config: &config, diff --git a/session/handler.go b/session/handler.go index 395f7a5..8ca274d 100644 --- a/session/handler.go +++ b/session/handler.go @@ -45,6 +45,7 @@ func (s *Session) handleGlobalRequest() { if portToBind == 80 || portToBind == 443 { s.TunnelType = HTTP + s.ForwardedPort = uint16(portToBind) var slug string for { slug = utils.GenerateRandomString(32) @@ -57,7 +58,11 @@ func (s *Session) handleGlobalRequest() { buf := new(bytes.Buffer) binary.Write(buf, binary.BigEndian, uint32(portToBind)) log.Printf("Forwarding approved on port: %d", portToBind) - s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to http://%s.%s \r\n", slug, utils.Getenv("domain")))) + if utils.Getenv("tls_enabled") == "true" { + s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to https://%s.%s \r\n", slug, utils.Getenv("domain")))) + } else { + s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to http://%s.%s \r\n", slug, utils.Getenv("domain")))) + } req.Reply(true, buf.Bytes()) } else { s.TunnelType = TCP diff --git a/session/session.go b/session/session.go index adec069..3aa41ee 100644 --- a/session/session.go +++ b/session/session.go @@ -12,6 +12,7 @@ type Session struct { GlobalRequest <-chan *ssh.Request Listener net.Listener TunnelType TunnelType + ForwardedPort uint16 Done chan bool }