fix: reject non tunnel request & reject duplicated port
This commit is contained in:
@ -111,9 +111,11 @@ func (s *Session) Close() {
|
||||
}
|
||||
|
||||
func (s *Session) handleGlobalRequest() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case req := <-s.GlobalRequest:
|
||||
ticker.Stop()
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
@ -124,6 +126,9 @@ func (s *Session) handleGlobalRequest() {
|
||||
}
|
||||
case <-s.Done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.sendMessage(fmt.Sprintf("Please specify the forwarding tunnel. For example: 'ssh %s -p %s -R 443:localhost:8080' \r\n\n\n", utils.Getenv("domain"), utils.Getenv("port")))
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -137,16 +142,30 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
||||
if err != nil {
|
||||
log.Println("Failed to read address from payload:", err)
|
||||
req.Reply(false, nil)
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var portToBind uint32
|
||||
if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil {
|
||||
log.Println("Failed to read port from payload:", err)
|
||||
s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
|
||||
req.Reply(false, nil)
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if isBlockedPort(portToBind) {
|
||||
s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
|
||||
req.Reply(false, nil)
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
|
||||
s.sendMessage("\033[H\033[2J")
|
||||
showWelcomeMessage(s.ConnChannels[0])
|
||||
s.Status = RUNNING
|
||||
|
||||
if portToBind == 80 || portToBind == 443 {
|
||||
s.handleHTTPForward(req, portToBind)
|
||||
return
|
||||
@ -155,6 +174,23 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
||||
s.handleTCPForward(req, addr, portToBind)
|
||||
}
|
||||
|
||||
var blockedReservedPorts = []uint32{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func isBlockedPort(port uint32) bool {
|
||||
if port == 80 || port == 443 {
|
||||
return false
|
||||
}
|
||||
if port < 1024 {
|
||||
return true
|
||||
}
|
||||
for _, p := range blockedReservedPorts {
|
||||
if p == port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint32) {
|
||||
s.TunnelType = HTTP
|
||||
s.ForwardedPort = uint16(portToBind)
|
||||
@ -190,13 +226,14 @@ func (s *Session) handleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||
if err != nil {
|
||||
log.Printf("Failed to bind to port %d: %v", portToBind, err)
|
||||
s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
|
||||
req.Reply(false, nil)
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
s.Listener = listener
|
||||
s.ForwardedPort = uint16(portToBind)
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, utils.Getenv("domain"), s.ForwardedPort))
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, utils.Getenv("domain"), s.ForwardedPort))
|
||||
|
||||
go s.acceptTCPConnections()
|
||||
|
||||
@ -466,7 +503,7 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
|
||||
}
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain))
|
||||
} else {
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%s \r\n", s.TunnelType, domain, s.ForwardedPort))
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, domain, s.ForwardedPort))
|
||||
}
|
||||
|
||||
case "/slug":
|
||||
@ -487,10 +524,6 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
|
||||
}
|
||||
|
||||
func (s *Session) handleChannelRequests(connection ssh.Channel, requests <-chan *ssh.Request) {
|
||||
connection.Write([]byte("\033[H\033[2J"))
|
||||
showWelcomeMessage(connection)
|
||||
s.Status = RUNNING
|
||||
|
||||
go s.handleGlobalRequest()
|
||||
|
||||
for req := range requests {
|
||||
|
||||
Reference in New Issue
Block a user