refactor: move CreateForwardedTCPIPPayload to forwarder interface
This commit is contained in:
@ -37,6 +37,7 @@ type ForwardingController interface {
|
||||
Close() error
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
|
||||
@ -53,8 +54,7 @@ func (f *Forwarder) AcceptTCPConnections() {
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
originHost, originPort := ParseAddr(conn.RemoteAddr().String())
|
||||
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), f.GetForwardedPort())
|
||||
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
@ -129,33 +129,18 @@ func (f *Forwarder) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseAddr(addr string) (string, uint32) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||
return "0.0.0.0", uint32(0)
|
||||
}
|
||||
port, _ := strconv.Atoi(portStr)
|
||||
return host, uint32(port)
|
||||
}
|
||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return
|
||||
}
|
||||
buffer.WriteString(str)
|
||||
}
|
||||
|
||||
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
|
||||
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
host, originPort := parseAddr(origin.String())
|
||||
|
||||
writeSSHString(&buf, "localhost")
|
||||
err := binary.Write(&buf, binary.BigEndian, uint32(port))
|
||||
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort()))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
writeSSHString(&buf, host)
|
||||
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
||||
if err != nil {
|
||||
@ -165,3 +150,22 @@ func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func parseAddr(addr string) (string, uint16) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||
return "0.0.0.0", uint16(0)
|
||||
}
|
||||
port, _ := strconv.Atoi(portStr)
|
||||
return host, uint16(port)
|
||||
}
|
||||
|
||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return
|
||||
}
|
||||
buffer.WriteString(str)
|
||||
}
|
||||
|
||||
@ -212,7 +212,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
|
||||
s.Forwarder.SetListener(listener)
|
||||
s.Forwarder.SetForwardedPort(portToBind)
|
||||
s.Interaction.ShowWelcomeMessage()
|
||||
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort()))
|
||||
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort()))
|
||||
|
||||
go s.Forwarder.AcceptTCPConnections()
|
||||
|
||||
|
||||
@ -324,7 +324,7 @@ func (i *Interaction) HandleCommand(command string) {
|
||||
}
|
||||
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))
|
||||
} else {
|
||||
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.Forwarder.GetTunnelType(), domain, i.Forwarder.GetForwardedPort()))
|
||||
i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort()))
|
||||
}
|
||||
case "/slug":
|
||||
if i.Forwarder.GetTunnelType() != types.HTTP {
|
||||
|
||||
Reference in New Issue
Block a user