refactor: move CreateForwardedTCPIPPayload to forwarder interface

This commit is contained in:
2025-12-05 13:49:33 +07:00
parent 7a31047bb9
commit 659f6c3ee7
7 changed files with 36 additions and 82 deletions

9
go.mod
View File

@ -3,13 +3,8 @@ module tunnel_pls
go 1.24.4 go 1.24.4
require ( require (
github.com/a-h/templ v0.3.833
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
golang.org/x/crypto v0.32.0 golang.org/x/crypto v0.45.0
golang.org/x/net v0.33.0
) )
require ( require golang.org/x/sys v0.38.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
golang.org/x/sys v0.29.0 // indirect
)

13
go.sum
View File

@ -1,16 +1,13 @@
github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU=
github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=

View File

@ -298,8 +298,7 @@ func Handler(conn net.Conn) {
} }
func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) {
originHost, originPort := ParseAddr(cw.RemoteAddr.String()) payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr)
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort())
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
if err != nil { if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err) log.Printf("Failed to open forwarded-tcpip channel: %v", err)

View File

@ -1,13 +1,10 @@
package server package server
import ( import (
"bytes"
"encoding/binary"
"fmt" "fmt"
"log" "log"
"net" "net"
"net/http" "net/http"
"strconv"
"tunnel_pls/utils" "tunnel_pls/utils"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -58,41 +55,3 @@ func (s *Server) Start() {
go s.handleConnection(conn) go s.handleConnection(conn)
} }
} }
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
var buf bytes.Buffer
writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(port))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
}
writeSSHString(&buf, host)
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
}
return buf.Bytes()
}
func writeSSHString(buffer *bytes.Buffer, str string) {
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return
}
buffer.WriteString(str)
}
func ParseAddr(addr string) (string, uint32) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
return "0.0.0.0", uint32(0)
}
port, _ := strconv.Atoi(portStr)
return host, uint32(port)
}

View File

@ -37,6 +37,7 @@ type ForwardingController interface {
Close() error Close() error
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
CreateForwardedTCPIPPayload(origin net.Addr) []byte
} }
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
@ -53,8 +54,7 @@ func (f *Forwarder) AcceptTCPConnections() {
log.Printf("Error accepting connection: %v", err) log.Printf("Error accepting connection: %v", err)
continue continue
} }
originHost, originPort := ParseAddr(conn.RemoteAddr().String()) payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), f.GetForwardedPort())
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
if err != nil { if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err) log.Printf("Failed to open forwarded-tcpip channel: %v", err)
@ -129,33 +129,18 @@ func (f *Forwarder) Close() error {
return nil return nil
} }
func ParseAddr(addr string) (string, uint32) { func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
return "0.0.0.0", uint32(0)
}
port, _ := strconv.Atoi(portStr)
return host, uint32(port)
}
func writeSSHString(buffer *bytes.Buffer, str string) {
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return
}
buffer.WriteString(str)
}
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
var buf bytes.Buffer var buf bytes.Buffer
host, originPort := parseAddr(origin.String())
writeSSHString(&buf, "localhost") writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(port)) err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort()))
if err != nil { if err != nil {
log.Printf("Failed to write string to buffer: %v", err) log.Printf("Failed to write string to buffer: %v", err)
return nil return nil
} }
writeSSHString(&buf, host) writeSSHString(&buf, host)
err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
if err != nil { if err != nil {
@ -165,3 +150,22 @@ func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
return buf.Bytes() return buf.Bytes()
} }
func parseAddr(addr string) (string, uint16) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
return "0.0.0.0", uint16(0)
}
port, _ := strconv.Atoi(portStr)
return host, uint16(port)
}
func writeSSHString(buffer *bytes.Buffer, str string) {
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return
}
buffer.WriteString(str)
}

View File

@ -212,7 +212,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
s.Forwarder.SetListener(listener) s.Forwarder.SetListener(listener)
s.Forwarder.SetForwardedPort(portToBind) s.Forwarder.SetForwardedPort(portToBind)
s.Interaction.ShowWelcomeMessage() s.Interaction.ShowWelcomeMessage()
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort()))
go s.Forwarder.AcceptTCPConnections() go s.Forwarder.AcceptTCPConnections()

View File

@ -324,7 +324,7 @@ func (i *Interaction) HandleCommand(command string) {
} }
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))
} else { } else {
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.Forwarder.GetTunnelType(), domain, i.Forwarder.GetForwardedPort())) i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort()))
} }
case "/slug": case "/slug":
if i.Forwarder.GetTunnelType() != types.HTTP { if i.Forwarder.GetTunnelType() != types.HTTP {