diff --git a/go.mod b/go.mod index 31fdc54..09be3c3 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,8 @@ module tunnel_pls go 1.24.4 require ( - github.com/a-h/templ v0.3.833 github.com/joho/godotenv v1.5.1 - golang.org/x/crypto v0.32.0 - golang.org/x/net v0.33.0 + golang.org/x/crypto v0.45.0 ) -require ( - github.com/gorilla/websocket v1.5.3 // indirect - golang.org/x/sys v0.29.0 // indirect -) +require golang.org/x/sys v0.38.0 // indirect diff --git a/go.sum b/go.sum index e14b727..27269bf 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,13 @@ -github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU= -github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= -github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= diff --git a/server/http.go b/server/http.go index a69b836..3d9ac0f 100644 --- a/server/http.go +++ b/server/http.go @@ -298,8 +298,7 @@ func Handler(conn net.Conn) { } func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { - originHost, originPort := ParseAddr(cw.RemoteAddr.String()) - payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort()) + payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr) channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) diff --git a/server/server.go b/server/server.go index 0e6bdb6..9d01817 100644 --- a/server/server.go +++ b/server/server.go @@ -1,13 +1,10 @@ package server import ( - "bytes" - "encoding/binary" "fmt" "log" "net" "net/http" - "strconv" "tunnel_pls/utils" "golang.org/x/crypto/ssh" @@ -58,41 +55,3 @@ func (s *Server) Start() { go s.handleConnection(conn) } } - -func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { - var buf bytes.Buffer - - writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(port)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - writeSSHString(&buf, host) - err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - - return buf.Bytes() -} - -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return - } - buffer.WriteString(str) -} - -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) -} diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 41c9602..450184d 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -37,6 +37,7 @@ type ForwardingController interface { Close() error HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) SetLifecycle(lifecycle Lifecycle) + CreateForwardedTCPIPPayload(origin net.Addr) []byte } func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { @@ -53,8 +54,7 @@ func (f *Forwarder) AcceptTCPConnections() { log.Printf("Error accepting connection: %v", err) continue } - originHost, originPort := ParseAddr(conn.RemoteAddr().String()) - payload := createForwardedTCPIPPayload(originHost, uint16(originPort), f.GetForwardedPort()) + payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) @@ -129,33 +129,18 @@ func (f *Forwarder) Close() error { return nil } -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) -} -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return - } - buffer.WriteString(str) -} - -func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { +func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { var buf bytes.Buffer + host, originPort := parseAddr(origin.String()) + writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(port)) + err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort())) if err != nil { log.Printf("Failed to write string to buffer: %v", err) return nil } + writeSSHString(&buf, host) err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) if err != nil { @@ -165,3 +150,22 @@ func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { return buf.Bytes() } + +func parseAddr(addr string) (string, uint16) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint16(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint16(port) +} + +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return + } + buffer.WriteString(str) +} diff --git a/session/handler.go b/session/handler.go index 9123310..e2a77f7 100644 --- a/session/handler.go +++ b/session/handler.go @@ -212,7 +212,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind s.Forwarder.SetListener(listener) s.Forwarder.SetForwardedPort(portToBind) s.Interaction.ShowWelcomeMessage() - s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) go s.Forwarder.AcceptTCPConnections() diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 0f6c3ca..3f3db3f 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -324,7 +324,7 @@ func (i *Interaction) HandleCommand(command string) { } i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) } else { - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.Forwarder.GetTunnelType(), domain, i.Forwarder.GetForwardedPort())) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort())) } case "/slug": if i.Forwarder.GetTunnelType() != types.HTTP {