update: use raw TCP for HTTP server

This commit is contained in:
2025-04-02 23:27:39 +07:00
parent 58f15d5a67
commit 221adf9581
8 changed files with 232 additions and 545 deletions

View File

@ -13,10 +13,14 @@ import (
"net/http"
"strconv"
"time"
"tunnel_pls/proto"
"tunnel_pls/utils"
)
type UserConnection struct {
Reader io.Reader
Writer net.Conn
}
func (s *Session) handleGlobalRequest() {
for {
select {
@ -25,81 +29,8 @@ func (s *Session) handleGlobalRequest() {
return
}
if req.Type == "tcpip-forward" {
log.Println("Port forwarding request detected")
reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader)
if err != nil {
log.Println("Failed to read address from payload:", err)
req.Reply(false, nil)
continue
}
var portToBind uint32
if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil {
log.Println("Failed to read port from payload:", err)
req.Reply(false, nil)
continue
}
if portToBind == 80 || portToBind == 443 {
s.TunnelType = HTTP
s.ForwardedPort = uint16(portToBind)
var slug string
for {
slug = utils.GenerateRandomString(32)
if _, ok := Clients[slug]; ok {
continue
}
break
}
Clients[slug] = s
s.Slug = slug
buf := new(bytes.Buffer)
binary.Write(buf, binary.BigEndian, uint32(portToBind))
log.Printf("Forwarding approved on port: %d", portToBind)
if utils.Getenv("tls_enabled") == "true" {
s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to https://%s.%s \r\n", slug, utils.Getenv("domain"))))
} else {
s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to http://%s.%s \r\n", slug, utils.Getenv("domain"))))
}
req.Reply(true, buf.Bytes())
} else {
s.TunnelType = TCP
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil {
log.Printf("Failed to bind to port %d: %v", portToBind, err)
req.Reply(false, nil)
continue
}
s.Listener = listener
s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to %s:%d \r\n", utils.Getenv("domain"), portToBind)))
go func() {
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
go s.HandleForwardedConnection(conn, s.Connection, portToBind)
}
}()
buf := new(bytes.Buffer)
binary.Write(buf, binary.BigEndian, uint32(portToBind))
log.Printf("Forwarding approved on port: %d", portToBind)
req.Reply(true, buf.Bytes())
}
s.handleTCPIPForward(req)
continue
} else {
req.Reply(false, nil)
}
@ -109,6 +40,88 @@ func (s *Session) handleGlobalRequest() {
}
}
func (s *Session) handleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected")
reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader)
if err != nil {
log.Println("Failed to read address from payload:", err)
req.Reply(false, nil)
return
}
var portToBind uint32
if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil {
log.Println("Failed to read port from payload:", err)
req.Reply(false, nil)
return
}
if portToBind == 80 || portToBind == 443 {
s.TunnelType = HTTP
s.ForwardedPort = uint16(portToBind)
var slug string
for {
slug = utils.GenerateRandomString(32)
if _, ok := Clients[slug]; ok {
return
}
break
}
Clients[slug] = s
s.Slug = slug
buf := new(bytes.Buffer)
binary.Write(buf, binary.BigEndian, uint32(80))
log.Printf("Forwarding approved on port: %d", 80)
if utils.Getenv("tls_enabled") == "true" {
s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to https://%s.%s \r\n", slug, utils.Getenv("domain"))))
} else {
s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to http://%s.%s \r\n", slug, utils.Getenv("domain"))))
}
req.Reply(true, buf.Bytes())
} else {
s.TunnelType = TCP
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil {
log.Printf("Failed to bind to port %d: %v", portToBind, err)
req.Reply(false, nil)
return
}
s.Listener = listener
s.ConnChannels[0].Write([]byte(fmt.Sprintf("Forwarding your traffic to %s:%d \r\n", utils.Getenv("domain"), portToBind)))
go func() {
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
go s.HandleForwardedConnection(UserConnection{
Reader: nil,
Writer: conn,
}, s.Connection, portToBind)
}
}()
buf := new(bytes.Buffer)
binary.Write(buf, binary.BigEndian, uint32(portToBind))
log.Printf("Forwarding approved on port: %d", portToBind)
req.Reply(true, buf.Bytes())
}
}
func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) {
connection, requests, err := newChannel.Accept()
s.ConnChannels = append(s.ConnChannels, connection)
@ -213,70 +226,53 @@ func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) {
}()
}
func (s *Session) HandleForwardedConnection(conn net.Conn, sshConn *ssh.ServerConn, port uint32) {
defer conn.Close()
log.Printf("Handling new forwarded connection from %s", conn.RemoteAddr())
host, originPort := ParseAddr(conn.RemoteAddr().String())
func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn, port uint32) {
defer conn.Writer.Close()
log.Printf("Handling new forwarded connection from %s", conn.Writer.RemoteAddr())
host, originPort := ParseAddr(conn.Writer.RemoteAddr().String())
payload := createForwardedTCPIPPayload(host, originPort, port)
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
defer channel.Close()
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
connReader := bufio.NewReader(conn)
var isHttp bool
header, err := connReader.Peek(7)
if err != nil {
isHttp = false
} else {
isHttp = proto.IsHttpRequest(header)
}
conn.SetReadDeadline(time.Time{})
go io.Copy(channel, connReader)
reader := bufio.NewReader(channel)
_, err = reader.Peek(1)
if err == io.EOF {
if isHttp {
io.Copy(conn, bytes.NewReader([]byte("HTTP/1.1 502 Bad Gateway\r\nContent-Length: 11\r\nContent-Type: text/plain\r\n\r\nBad Gateway")))
} else {
conn.Write([]byte("Could not forward request to the tunnel addr\r\n"))
}
s.ConnChannels[0].Write([]byte("Could not forward request to the tunnel addr\r\n"))
return
} else {
io.Copy(conn, reader)
}
go func() {
for req := range reqs {
req.Reply(false, nil)
}
}()
}
func (s *Session) GetForwardedConnection(conn net.Conn, host string, sshConn *ssh.ServerConn, payload []byte, originPort, port uint32, path, method, proto string) {
defer conn.Close()
channelPayload := createForwardedTCPIPPayload(host, originPort, port)
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", channelPayload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
defer channel.Close()
if conn.Reader == nil {
conn.Reader = bufio.NewReader(conn.Writer)
}
go io.Copy(channel, conn.Reader)
reader := bufio.NewReader(channel)
_, err = reader.Peek(1)
if err == io.EOF {
fmt.Println("error babi")
}
io.Copy(conn.Writer, reader)
}
connReader := bufio.NewReader(conn)
initalPayload := bytes.NewReader(payload)
io.Copy(channel, initalPayload)
go io.Copy(channel, connReader)
func (s *Session) HandleForwardedConnectionHTTP(conn net.Conn, sshConn *ssh.ServerConn, request *http.Request) {
defer conn.Close()
fmt.Println(request)
channelPayload := createForwardedTCPIPPayload(request.Host, 80, 80)
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", channelPayload)
go func() {
for req := range reqs {
req.Reply(false, nil)
}
}()
var requestBuffer bytes.Buffer
if err := request.Write(&requestBuffer); err != nil {
fmt.Println("Error serializing request:", err)
channel.Close()
conn.Close()
return
}
channel.Write(requestBuffer.Bytes())
reader := bufio.NewReader(channel)
_, err = reader.Peek(1)
@ -285,46 +281,42 @@ func (s *Session) GetForwardedConnection(conn net.Conn, host string, sshConn *ss
s.ConnChannels[0].Write([]byte("Could not forward request to the tunnel addr\r\n"))
return
} else {
s.ConnChannels[0].Write([]byte(fmt.Sprintf("\033[32m %s -- [%s] \"%s %s %s\" \r\n \033[0m", host, time.Now().Format("02/Jan/2006 15:04:05"), method, path, proto)))
s.ConnChannels[0].Write([]byte(fmt.Sprintf("\033[32m %s -- [%s] \"%s %s %s\" \r\n \033[0m", request.Host, time.Now().Format("02/Jan/2006 15:04:05"), request.Method, request.RequestURI, request.Proto)))
io.Copy(conn, reader)
}
go func() {
for req := range reqs {
req.Reply(false, nil)
}
}()
}
func (s *Session) GetForwardedConnectionTLS(host string, sshConn *ssh.ServerConn, payload []byte, originPort, port uint32, path, method, proto string) *http.Response {
channelPayload := createForwardedTCPIPPayload(host, originPort, port)
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", channelPayload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return nil
}
defer channel.Close()
initalPayload := bytes.NewReader(payload)
io.Copy(channel, initalPayload)
go func() {
for req := range reqs {
req.Reply(false, nil)
}
}()
reader := bufio.NewReader(channel)
_, err = reader.Peek(1)
if err == io.EOF {
s.ConnChannels[0].Write([]byte("Could not forward request to the tunnel addr\r\n"))
return nil
} else {
s.ConnChannels[0].Write([]byte(fmt.Sprintf("\033[32m %s -- [%s] \"%s %s %s\" \r\n \033[0m", host, time.Now().Format("02/Jan/2006 15:04:05"), method, path, proto)))
response, _ := http.ReadResponse(reader, nil)
return response
}
}
//TODO: Implement HTTPS forwarding
//func (s *Session) GetForwardedConnectionTLS(host string, sshConn *ssh.ServerConn, payload []byte, originPort, port uint32, path, method, proto string) (*http.Response, error) {
// channelPayload := createForwardedTCPIPPayload(host, originPort, port)
// channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", channelPayload)
// if err != nil {
// return nil, err
// }
// defer channel.Close()
//
// initalPayload := bytes.NewReader(payload)
// io.Copy(channel, initalPayload)
//
// go func() {
// for req := range reqs {
// req.Reply(false, nil)
// }
// }()
//
// reader := bufio.NewReader(channel)
// _, err = reader.Peek(1)
// if err == io.EOF {
// return nil, err
// } else {
// s.ConnChannels[0].Write([]byte(fmt.Sprintf("\033[32m %s -- [%s] \"%s %s %s\" \r\n \033[0m", host, time.Now().Format("02/Jan/2006 15:04:05"), method, path, proto)))
// response, err := http.ReadResponse(reader, nil)
// if err != nil {
// return nil, err
// }
// return response, err
// }
//}
func writeSSHString(buffer *bytes.Buffer, str string) {
binary.Write(buffer, binary.BigEndian, uint32(len(str)))

View File

@ -8,15 +8,16 @@ import (
)
type Session struct {
ID uuid.UUID
Slug string
ConnChannels []ssh.Channel
Connection *ssh.ServerConn
GlobalRequest <-chan *ssh.Request
Listener net.Listener
TunnelType TunnelType
ForwardedPort uint16
Done chan bool
ID uuid.UUID
Slug string
ConnChannels []ssh.Channel
Connection *ssh.ServerConn
GlobalRequest <-chan *ssh.Request
Listener net.Listener
TunnelType TunnelType
ForwardedPort uint16
Done chan bool
ForwardedChannel ssh.Channel
}
type TunnelType string