feat: close connection if no tunneling request is specified
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 7m18s

This commit is contained in:
2025-10-20 12:28:52 +00:00
parent 659b2b82ec
commit e02b7ed937
7 changed files with 82 additions and 66 deletions

View File

@ -92,7 +92,6 @@ func (s *Session) Close() error {
if s.Listener != nil {
err := s.Listener.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
fmt.Println("1")
return err
}
}
@ -100,7 +99,6 @@ func (s *Session) Close() error {
if s.ConnChannel != nil {
err := s.ConnChannel.Close()
if err != nil && !errors.Is(err, io.EOF) {
fmt.Println("2")
return err
}
}
@ -108,8 +106,6 @@ func (s *Session) Close() error {
if s.Connection != nil {
err := s.Connection.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
fmt.Println("3")
return err
}
}
@ -121,7 +117,6 @@ func (s *Session) Close() error {
if s.TunnelType == TCP {
err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
if err != nil {
fmt.Println("4")
return err
}
}
@ -131,16 +126,35 @@ func (s *Session) Close() error {
}
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest {
switch req.Type {
case "tcpip-forward":
s.handleTCPIPForward(req)
return
case "shell", "pty-req", "window-change":
req.Reply(true, nil)
default:
log.Println("Unknown request type:", req.Type)
req.Reply(false, nil)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
for {
select {
case req, ok := <-GlobalRequest:
if !ok || req == nil {
log.Println("GlobalRequest channel closed")
return
}
switch req.Type {
case "tcpip-forward":
cancel()
s.handleTCPIPForward(req)
return
case "shell", "pty-req", "window-change":
req.Reply(true, nil)
default:
log.Println("Unknown request type:", req.Type)
req.Reply(false, nil)
}
case <-ctx.Done():
if s.Status == SETUP {
s.sendMessage("No forwarding request detected. See https://tunnl.live for setup help.\n\r")
err := s.Close()
if err != nil {
log.Println("Cannot close connection: ", err)
return
}
}
}
}
}
@ -216,15 +230,15 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 {
func isBlockedPort(SSH_PORT uint16) bool {
if SSH_PORT == 80 || SSH_PORT == 443 {
return false
}
if port < 1024 && port != 0 {
if SSH_PORT < 1024 && SSH_PORT != 0 {
return true
}
for _, p := range blockedReservedPorts {
if p == port {
if p == SSH_PORT {
return true
}
}
@ -250,13 +264,13 @@ func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint16) {
s.waitForRunningStatus()
domain := utils.Getenv("domain")
DOMAIN := utils.Getenv("DOMAIN")
protocol := "http"
if utils.Getenv("tls_enabled") == "true" {
if utils.Getenv("TLS_ENABLED") == "true" {
protocol = "https"
}
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain))
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, DOMAIN))
req.Reply(true, buf.Bytes())
}
@ -403,13 +417,13 @@ func (s *Session) handleSlugEditMode(connection ssh.Channel, inSlugEditMode *boo
if len(*editSlug) > 0 {
*editSlug = (*editSlug)[:len(*editSlug)-1]
connection.Write([]byte("\r\033[K"))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain")))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN")))
}
} else if char >= 32 && char <= 126 {
if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' {
*editSlug += string(char)
connection.Write([]byte("\r\033[K"))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain")))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN")))
}
}
}
@ -438,7 +452,7 @@ func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, e
}
connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n"))
connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n"))
connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("DOMAIN") + "\r\n\r\n"))
connection.Write([]byte("Press any key to continue...\r\n"))
} else if isForbiddenSlug(*editSlug) {
connection.Write([]byte("\r\n\r\n❌ FORBIDDEN SUBDOMAIN ❌\r\n\r\n"))
@ -457,12 +471,12 @@ func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, e
connection.Write([]byte("\033[H\033[2J"))
showWelcomeMessage(connection)
domain := utils.Getenv("domain")
DOMAIN := utils.Getenv("DOMAIN")
protocol := "http"
if utils.Getenv("tls_enabled") == "true" {
if utils.Getenv("TLS_ENABLED") == "true" {
protocol = "https"
}
connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain)))
connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, DOMAIN)))
*inSlugEditMode = false
commandBuffer.Reset()
@ -534,15 +548,15 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
case "/clear":
connection.Write([]byte("\033[H\033[2J"))
showWelcomeMessage(s.ConnChannel)
domain := utils.Getenv("domain")
DOMAIN := utils.Getenv("DOMAIN")
if s.TunnelType == HTTP {
protocol := "http"
if utils.Getenv("tls_enabled") == "true" {
if utils.Getenv("TLS_ENABLED") == "true" {
protocol = "https"
}
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain))
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, DOMAIN))
} else {
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, domain, s.ForwardedPort))
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, DOMAIN, s.ForwardedPort))
}
case "/slug":
@ -553,7 +567,7 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
*editSlug = s.Slug
connection.Write([]byte("\033[H\033[2J"))
displaySlugEditor(connection, s.Slug)
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain")))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN")))
}
default:
connection.Write([]byte("\r\nUnknown command"))
@ -684,8 +698,8 @@ func showWelcomeMessage(connection ssh.Channel) {
}
func displaySlugEditor(connection ssh.Channel, currentSlug string) {
domain := utils.Getenv("domain")
fullDomain := currentSlug + "." + domain
DOMAIN := utils.Getenv("DOMAIN")
fullDomain := currentSlug + "." + DOMAIN
const paddingRight = 4
@ -742,15 +756,15 @@ func ParseAddr(addr string) (string, uint32) {
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
return "0.0.0.0", uint32(0)
}
port, _ := strconv.Atoi(portStr)
return host, uint32(port)
SSH_PORT, _ := strconv.Atoi(portStr)
return host, uint32(SSH_PORT)
}
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
func createForwardedTCPIPPayload(host string, originPort, SSH_PORT uint16) []byte {
var buf bytes.Buffer
writeSSHString(&buf, "localhost")
binary.Write(&buf, binary.BigEndian, uint32(port))
binary.Write(&buf, binary.BigEndian, uint32(SSH_PORT))
writeSSHString(&buf, host)
binary.Write(&buf, binary.BigEndian, uint32(originPort))