fix: reject non tunnel request & reject duplicated port

This commit is contained in:
2025-04-08 23:14:42 +07:00
parent 5350bc13a9
commit 0117931817
6 changed files with 51 additions and 16 deletions

View File

@ -9,12 +9,16 @@ import (
func main() { func main() {
sshConfig := &ssh.ServerConfig{ sshConfig := &ssh.ServerConfig{
NoClientAuth: true, NoClientAuth: true,
ServerVersion: "SSH-2.0-TunnlPls-1.0",
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
return nil, nil return nil, nil
}, },
} }
log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile)
privateBytes, err := os.ReadFile("id_rsa") privateBytes, err := os.ReadFile("id_rsa")
if err != nil { if err != nil {
log.Fatal("Failed to load private key (./id_rsa)") log.Fatal("Failed to load private key (./id_rsa)")

View File

@ -1,7 +1,6 @@
package server package server
import ( import (
"fmt"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"log" "log"
"net" "net"
@ -16,7 +15,7 @@ func (s *Server) handleConnection(conn net.Conn) {
return return
} }
fmt.Println("SSH connection established:", sshConn.User()) log.Println("SSH connection established:", sshConn.User())
session.New(sshConn, chans, reqs) session.New(sshConn, chans, reqs)
} }

View File

@ -46,14 +46,14 @@ func Handler(conn net.Conn) {
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
headers, err := peekUntilHeaders(reader, 8192) headers, err := peekUntilHeaders(reader, 8192)
if err != nil { if err != nil {
fmt.Println("Failed to peek headers:", err) log.Println("Failed to peek headers:", err)
return return
} }
host := strings.Split(parseHostFromHeader(headers), ".") host := strings.Split(parseHostFromHeader(headers), ".")
if len(host) < 1 { if len(host) < 1 {
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")) conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
fmt.Println("Bad Request") log.Println("Bad Request")
conn.Close() conn.Close()
return return
} }

View File

@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"golang.org/x/net/context" "golang.org/x/net/context"
"log" "log"
"net" "net"
@ -46,7 +45,7 @@ func HandlerTLS(conn net.Conn) {
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
headers, err := peekUntilHeaders(reader, 8192) headers, err := peekUntilHeaders(reader, 8192)
if err != nil { if err != nil {
fmt.Println("Failed to peek headers:", err) log.Println("Failed to peek headers:", err)
return return
} }

View File

@ -16,7 +16,7 @@ type Server struct {
} }
func NewServer(config ssh.ServerConfig) *Server { func NewServer(config ssh.ServerConfig) *Server {
listener, err := net.Listen("tcp", ":2200") listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("port")))
if err != nil { if err != nil {
log.Fatalf("failed to listen on port 2200: %v", err) log.Fatalf("failed to listen on port 2200: %v", err)
return nil return nil
@ -45,7 +45,7 @@ func NewServer(config ssh.ServerConfig) *Server {
} }
func (s *Server) Start() { func (s *Server) Start() {
fmt.Println("SSH server is starting on port 2200...") log.Println("SSH server is starting on port 2200...")
for { for {
conn, err := (*s.Conn).Accept() conn, err := (*s.Conn).Accept()
if err != nil { if err != nil {

View File

@ -111,9 +111,11 @@ func (s *Session) Close() {
} }
func (s *Session) handleGlobalRequest() { func (s *Session) handleGlobalRequest() {
ticker := time.NewTicker(1 * time.Second)
for { for {
select { select {
case req := <-s.GlobalRequest: case req := <-s.GlobalRequest:
ticker.Stop()
if req == nil { if req == nil {
return return
} }
@ -124,6 +126,9 @@ func (s *Session) handleGlobalRequest() {
} }
case <-s.Done: case <-s.Done:
return return
case <-ticker.C:
s.sendMessage(fmt.Sprintf("Please specify the forwarding tunnel. For example: 'ssh %s -p %s -R 443:localhost:8080' \r\n\n\n", utils.Getenv("domain"), utils.Getenv("port")))
s.Close()
} }
} }
} }
@ -137,16 +142,30 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
if err != nil { if err != nil {
log.Println("Failed to read address from payload:", err) log.Println("Failed to read address from payload:", err)
req.Reply(false, nil) req.Reply(false, nil)
s.Close()
return return
} }
var portToBind uint32 var portToBind uint32
if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil { if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil {
log.Println("Failed to read port from payload:", err) log.Println("Failed to read port from payload:", err)
s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
req.Reply(false, nil) req.Reply(false, nil)
s.Close()
return return
} }
if isBlockedPort(portToBind) {
s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
req.Reply(false, nil)
s.Close()
return
}
s.sendMessage("\033[H\033[2J")
showWelcomeMessage(s.ConnChannels[0])
s.Status = RUNNING
if portToBind == 80 || portToBind == 443 { if portToBind == 80 || portToBind == 443 {
s.handleHTTPForward(req, portToBind) s.handleHTTPForward(req, portToBind)
return return
@ -155,6 +174,23 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
s.handleTCPForward(req, addr, portToBind) s.handleTCPForward(req, addr, portToBind)
} }
var blockedReservedPorts = []uint32{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func isBlockedPort(port uint32) bool {
if port == 80 || port == 443 {
return false
}
if port < 1024 {
return true
}
for _, p := range blockedReservedPorts {
if p == port {
return true
}
}
return false
}
func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint32) { func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint32) {
s.TunnelType = HTTP s.TunnelType = HTTP
s.ForwardedPort = uint16(portToBind) s.ForwardedPort = uint16(portToBind)
@ -190,13 +226,14 @@ func (s *Session) handleTCPForward(req *ssh.Request, addr string, portToBind uin
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil { if err != nil {
log.Printf("Failed to bind to port %d: %v", portToBind, err) s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
req.Reply(false, nil) req.Reply(false, nil)
s.Close()
return return
} }
s.Listener = listener s.Listener = listener
s.ForwardedPort = uint16(portToBind) s.ForwardedPort = uint16(portToBind)
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, utils.Getenv("domain"), s.ForwardedPort)) s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, utils.Getenv("domain"), s.ForwardedPort))
go s.acceptTCPConnections() go s.acceptTCPConnections()
@ -466,7 +503,7 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
} }
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain)) s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain))
} else { } else {
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, domain, s.ForwardedPort)) s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, domain, s.ForwardedPort))
} }
case "/slug": case "/slug":
@ -487,10 +524,6 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
} }
func (s *Session) handleChannelRequests(connection ssh.Channel, requests <-chan *ssh.Request) { func (s *Session) handleChannelRequests(connection ssh.Channel, requests <-chan *ssh.Request) {
connection.Write([]byte("\033[H\033[2J"))
showWelcomeMessage(connection)
s.Status = RUNNING
go s.handleGlobalRequest() go s.handleGlobalRequest()
for req := range requests { for req := range requests {