feat: add subdomain forwarding support for tunnel

This commit is contained in:
2025-02-07 03:26:01 +07:00
parent 8a1604fde8
commit 82eb7af7a6
2 changed files with 127 additions and 72 deletions

View File

@ -3,6 +3,7 @@ package session
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
@ -189,7 +190,7 @@ func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) {
case "shell", "pty-req", "window-change":
req.Reply(true, nil)
default:
fmt.Println("Unknown request type")
fmt.Println("Unknown request type of : ", req.Type)
req.Reply(false, nil)
}
}
@ -199,8 +200,8 @@ func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) {
func (s *Session) HandleForwardedConnection(conn net.Conn, sshConn *ssh.ServerConn, port uint32) {
defer conn.Close()
log.Printf("Handling new forwarded connection from %s", conn.RemoteAddr())
payload := createForwardedTCPIPPayload(conn, port)
host, originPort := ParseAddr(conn.RemoteAddr().String())
payload := createForwardedTCPIPPayload(host, originPort, port)
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
@ -236,14 +237,6 @@ func (s *Session) HandleForwardedConnection(conn net.Conn, sshConn *ssh.ServerCo
s.ConnChannels[0].Write([]byte("Could not forward request to the tunnel addr\r\n"))
return
} else {
//if isHttp {
// response, err := http.ReadResponse(reader, nil)
// if err != nil {
// return
// }
// fmt.Println(response)
//}
io.Copy(conn, reader)
}
@ -254,24 +247,85 @@ func (s *Session) HandleForwardedConnection(conn net.Conn, sshConn *ssh.ServerCo
}()
}
func (s *Session) GetForwardedConnection(host string, sshConn *ssh.ServerConn, payload []byte, originPort, port uint32) []byte {
fmt.Println("Here 1")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
channelPayload := createForwardedTCPIPPayload(host, originPort, port)
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", channelPayload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return nil
}
fmt.Println("Here 2")
defer channel.Close()
head := bytes.NewReader(payload)
go io.Copy(channel, head)
fmt.Println("Here 3")
go func() {
for req := range reqs {
req.Reply(false, nil)
}
}()
fmt.Println("Here 4")
var data bytes.Buffer
done := make(chan error, 1)
go func() {
io.Copy(&data, channel)
done <- err
}()
go func() {
var lastSize int
ticker := time.NewTicker(100)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
currentSize := data.Len()
fmt.Println("Size buffer:", currentSize)
if currentSize == lastSize && currentSize > 0 {
fmt.Println("Buffer size unchanged, closing channel...")
cancel()
return
}
lastSize = currentSize
}
}
}()
select {
case <-ctx.Done():
return data.Bytes()
case err = <-done:
return data.Bytes()
}
}
func writeSSHString(buffer *bytes.Buffer, str string) {
binary.Write(buffer, binary.BigEndian, uint32(len(str)))
buffer.WriteString(str)
}
func parseAddr(addr string) (string, int) {
func ParseAddr(addr string) (string, uint32) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
log.Println("Failed to parse origin address:", err)
return "0.0.0.0", 0
return "0.0.0.0", uint32(0)
}
port, _ := strconv.Atoi(portStr)
return host, port
return host, uint32(port)
}
func createForwardedTCPIPPayload(conn net.Conn, port uint32) []byte {
func createForwardedTCPIPPayload(host string, originPort, port uint32) []byte {
var buf bytes.Buffer
host, originPort := parseAddr(conn.RemoteAddr().String())
writeSSHString(&buf, "localhost")
binary.Write(&buf, binary.BigEndian, uint32(port))