feat: add TLS support
This commit is contained in:
0
certs/cert.pem
Normal file
0
certs/cert.pem
Normal file
20
http/http.go
20
http/http.go
@ -127,15 +127,19 @@ func handleRequest(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Host == utils.Getenv("domain") {
|
writer := &tcpResponseWriter{
|
||||||
writer := &tcpResponseWriter{
|
conn: conn,
|
||||||
conn: conn,
|
header: make(http.Header),
|
||||||
header: make(http.Header),
|
status: http.StatusOK,
|
||||||
status: http.StatusOK,
|
}
|
||||||
}
|
|
||||||
fmt.Println(r.Pattern)
|
|
||||||
router.ServeHTTP(writer, r)
|
|
||||||
|
|
||||||
|
if r.Host == utils.Getenv("domain") {
|
||||||
|
router.ServeHTTP(writer, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if utils.Getenv("tls_enabled") == "false" {
|
||||||
|
http.Redirect(writer, r, fmt.Sprintf("https://%s%s", r.Host, r.URL.RequestURI()), http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
95
http/https.go
Normal file
95
http/https.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package httpServer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"tunnel_pls/session"
|
||||||
|
"tunnel_pls/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ListenTLS(config *tls.Config) {
|
||||||
|
server, err := tls.Listen("tcp", ":443", config)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer server.Close()
|
||||||
|
log.Println("Listening on :443")
|
||||||
|
for {
|
||||||
|
conn, err := server.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go handleRequestTLS(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleRequestTLS(conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
var rawRequest string
|
||||||
|
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
|
r, err := http.ReadRequest(reader)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error reading request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writer := &tcpResponseWriter{
|
||||||
|
conn: conn,
|
||||||
|
header: make(http.Header),
|
||||||
|
status: http.StatusOK,
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Host == utils.Getenv("domain") {
|
||||||
|
router.ServeHTTP(writer, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slug := strings.Split(r.Host, ".")[0]
|
||||||
|
if slug == "" {
|
||||||
|
fmt.Println("Error parsing slug: ", r.Host)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sshSession, ok := session.Clients[slug]
|
||||||
|
if !ok {
|
||||||
|
fmt.Println("Error finding ssh session: ", slug)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawRequest += fmt.Sprintf("%s %s %s\r\n", r.Method, r.URL.RequestURI(), r.Proto)
|
||||||
|
rawRequest += fmt.Sprintf("Host: %s\r\n", r.Host)
|
||||||
|
|
||||||
|
for k, v := range r.Header {
|
||||||
|
rawRequest += fmt.Sprintf("%s: %s\r\n", k, v[0])
|
||||||
|
}
|
||||||
|
rawRequest += "\r\n"
|
||||||
|
|
||||||
|
if r.Body != nil {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error reading request body:", err)
|
||||||
|
} else {
|
||||||
|
rawRequest += string(body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte(rawRequest)
|
||||||
|
|
||||||
|
host, originPort := session.ParseAddr(conn.RemoteAddr().String())
|
||||||
|
sshSession.GetForwardedConnection(conn, host, sshSession.Connection, payload, originPort, 80, r.RequestURI, r.Method, r.Proto)
|
||||||
|
}
|
||||||
@ -1,12 +1,14 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
httpServer "tunnel_pls/http"
|
httpServer "tunnel_pls/http"
|
||||||
|
"tunnel_pls/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@ -22,6 +24,15 @@ func NewServer(config ssh.ServerConfig) *Server {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
go httpServer.Listen()
|
go httpServer.Listen()
|
||||||
|
if utils.Getenv("tls_enabled") == "true" {
|
||||||
|
cert, err := tls.LoadX509KeyPair(utils.Getenv("cert_loc"), utils.Getenv("key_loc"))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Failed to load key pair:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||||
|
go httpServer.ListenTLS(tlsConfig)
|
||||||
|
}
|
||||||
return &Server{
|
return &Server{
|
||||||
Conn: &listener,
|
Conn: &listener,
|
||||||
Config: &config,
|
Config: &config,
|
||||||
|
|||||||
@ -45,6 +45,7 @@ func (s *Session) handleGlobalRequest() {
|
|||||||
|
|
||||||
if portToBind == 80 || portToBind == 443 {
|
if portToBind == 80 || portToBind == 443 {
|
||||||
s.TunnelType = HTTP
|
s.TunnelType = HTTP
|
||||||
|
s.ForwardedPort = uint16(portToBind)
|
||||||
var slug string
|
var slug string
|
||||||
for {
|
for {
|
||||||
slug = utils.GenerateRandomString(32)
|
slug = utils.GenerateRandomString(32)
|
||||||
@ -57,7 +58,11 @@ func (s *Session) handleGlobalRequest() {
|
|||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
||||||
log.Printf("Forwarding approved on port: %d", portToBind)
|
log.Printf("Forwarding approved on port: %d", portToBind)
|
||||||
s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to http://%s.%s \r\n", slug, utils.Getenv("domain"))))
|
if utils.Getenv("tls_enabled") == "true" {
|
||||||
|
s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to https://%s.%s \r\n", slug, utils.Getenv("domain"))))
|
||||||
|
} else {
|
||||||
|
s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to http://%s.%s \r\n", slug, utils.Getenv("domain"))))
|
||||||
|
}
|
||||||
req.Reply(true, buf.Bytes())
|
req.Reply(true, buf.Bytes())
|
||||||
} else {
|
} else {
|
||||||
s.TunnelType = TCP
|
s.TunnelType = TCP
|
||||||
|
|||||||
@ -12,6 +12,7 @@ type Session struct {
|
|||||||
GlobalRequest <-chan *ssh.Request
|
GlobalRequest <-chan *ssh.Request
|
||||||
Listener net.Listener
|
Listener net.Listener
|
||||||
TunnelType TunnelType
|
TunnelType TunnelType
|
||||||
|
ForwardedPort uint16
|
||||||
Done chan bool
|
Done chan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user