fix: reject non tunnel request & reject duplicated port
This commit is contained in:
4
main.go
4
main.go
@ -10,11 +10,15 @@ import (
|
|||||||
func main() {
|
func main() {
|
||||||
sshConfig := &ssh.ServerConfig{
|
sshConfig := &ssh.ServerConfig{
|
||||||
NoClientAuth: true,
|
NoClientAuth: true,
|
||||||
|
ServerVersion: "SSH-2.0-TunnlPls-1.0",
|
||||||
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.SetOutput(os.Stdout)
|
||||||
|
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||||
|
|
||||||
privateBytes, err := os.ReadFile("id_rsa")
|
privateBytes, err := os.ReadFile("id_rsa")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Failed to load private key (./id_rsa)")
|
log.Fatal("Failed to load private key (./id_rsa)")
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -16,7 +15,7 @@ func (s *Server) handleConnection(conn net.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("SSH connection established:", sshConn.User())
|
log.Println("SSH connection established:", sshConn.User())
|
||||||
|
|
||||||
session.New(sshConn, chans, reqs)
|
session.New(sshConn, chans, reqs)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -46,14 +46,14 @@ func Handler(conn net.Conn) {
|
|||||||
reader := bufio.NewReader(conn)
|
reader := bufio.NewReader(conn)
|
||||||
headers, err := peekUntilHeaders(reader, 8192)
|
headers, err := peekUntilHeaders(reader, 8192)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Failed to peek headers:", err)
|
log.Println("Failed to peek headers:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
host := strings.Split(parseHostFromHeader(headers), ".")
|
host := strings.Split(parseHostFromHeader(headers), ".")
|
||||||
if len(host) < 1 {
|
if len(host) < 1 {
|
||||||
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||||
fmt.Println("Bad Request")
|
log.Println("Bad Request")
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,7 +4,6 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -46,7 +45,7 @@ func HandlerTLS(conn net.Conn) {
|
|||||||
reader := bufio.NewReader(conn)
|
reader := bufio.NewReader(conn)
|
||||||
headers, err := peekUntilHeaders(reader, 8192)
|
headers, err := peekUntilHeaders(reader, 8192)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Failed to peek headers:", err)
|
log.Println("Failed to peek headers:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@ type Server struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(config ssh.ServerConfig) *Server {
|
func NewServer(config ssh.ServerConfig) *Server {
|
||||||
listener, err := net.Listen("tcp", ":2200")
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("port")))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to listen on port 2200: %v", err)
|
log.Fatalf("failed to listen on port 2200: %v", err)
|
||||||
return nil
|
return nil
|
||||||
@ -45,7 +45,7 @@ func NewServer(config ssh.ServerConfig) *Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Start() {
|
func (s *Server) Start() {
|
||||||
fmt.Println("SSH server is starting on port 2200...")
|
log.Println("SSH server is starting on port 2200...")
|
||||||
for {
|
for {
|
||||||
conn, err := (*s.Conn).Accept()
|
conn, err := (*s.Conn).Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -111,9 +111,11 @@ func (s *Session) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) handleGlobalRequest() {
|
func (s *Session) handleGlobalRequest() {
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case req := <-s.GlobalRequest:
|
case req := <-s.GlobalRequest:
|
||||||
|
ticker.Stop()
|
||||||
if req == nil {
|
if req == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -124,6 +126,9 @@ func (s *Session) handleGlobalRequest() {
|
|||||||
}
|
}
|
||||||
case <-s.Done:
|
case <-s.Done:
|
||||||
return
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
s.sendMessage(fmt.Sprintf("Please specify the forwarding tunnel. For example: 'ssh %s -p %s -R 443:localhost:8080' \r\n\n\n", utils.Getenv("domain"), utils.Getenv("port")))
|
||||||
|
s.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -137,16 +142,30 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to read address from payload:", err)
|
log.Println("Failed to read address from payload:", err)
|
||||||
req.Reply(false, nil)
|
req.Reply(false, nil)
|
||||||
|
s.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var portToBind uint32
|
var portToBind uint32
|
||||||
if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil {
|
if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil {
|
||||||
log.Println("Failed to read port from payload:", err)
|
log.Println("Failed to read port from payload:", err)
|
||||||
|
s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
|
||||||
req.Reply(false, nil)
|
req.Reply(false, nil)
|
||||||
|
s.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isBlockedPort(portToBind) {
|
||||||
|
s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
|
||||||
|
req.Reply(false, nil)
|
||||||
|
s.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendMessage("\033[H\033[2J")
|
||||||
|
showWelcomeMessage(s.ConnChannels[0])
|
||||||
|
s.Status = RUNNING
|
||||||
|
|
||||||
if portToBind == 80 || portToBind == 443 {
|
if portToBind == 80 || portToBind == 443 {
|
||||||
s.handleHTTPForward(req, portToBind)
|
s.handleHTTPForward(req, portToBind)
|
||||||
return
|
return
|
||||||
@ -155,6 +174,23 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
|||||||
s.handleTCPForward(req, addr, portToBind)
|
s.handleTCPForward(req, addr, portToBind)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var blockedReservedPorts = []uint32{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||||
|
|
||||||
|
func isBlockedPort(port uint32) bool {
|
||||||
|
if port == 80 || port == 443 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if port < 1024 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, p := range blockedReservedPorts {
|
||||||
|
if p == port {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint32) {
|
func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint32) {
|
||||||
s.TunnelType = HTTP
|
s.TunnelType = HTTP
|
||||||
s.ForwardedPort = uint16(portToBind)
|
s.ForwardedPort = uint16(portToBind)
|
||||||
@ -190,13 +226,14 @@ func (s *Session) handleTCPForward(req *ssh.Request, addr string, portToBind uin
|
|||||||
|
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to bind to port %d: %v", portToBind, err)
|
s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
|
||||||
req.Reply(false, nil)
|
req.Reply(false, nil)
|
||||||
|
s.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.Listener = listener
|
s.Listener = listener
|
||||||
s.ForwardedPort = uint16(portToBind)
|
s.ForwardedPort = uint16(portToBind)
|
||||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, utils.Getenv("domain"), s.ForwardedPort))
|
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, utils.Getenv("domain"), s.ForwardedPort))
|
||||||
|
|
||||||
go s.acceptTCPConnections()
|
go s.acceptTCPConnections()
|
||||||
|
|
||||||
@ -466,7 +503,7 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
|
|||||||
}
|
}
|
||||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain))
|
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain))
|
||||||
} else {
|
} else {
|
||||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, domain, s.ForwardedPort))
|
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, domain, s.ForwardedPort))
|
||||||
}
|
}
|
||||||
|
|
||||||
case "/slug":
|
case "/slug":
|
||||||
@ -487,10 +524,6 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) handleChannelRequests(connection ssh.Channel, requests <-chan *ssh.Request) {
|
func (s *Session) handleChannelRequests(connection ssh.Channel, requests <-chan *ssh.Request) {
|
||||||
connection.Write([]byte("\033[H\033[2J"))
|
|
||||||
showWelcomeMessage(connection)
|
|
||||||
s.Status = RUNNING
|
|
||||||
|
|
||||||
go s.handleGlobalRequest()
|
go s.handleGlobalRequest()
|
||||||
|
|
||||||
for req := range requests {
|
for req := range requests {
|
||||||
|
|||||||
Reference in New Issue
Block a user