feat: close connection if no tunneling request is specified
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 7m18s

This commit is contained in:
2025-10-20 12:28:52 +00:00
parent 659b2b82ec
commit e02b7ed937
7 changed files with 82 additions and 66 deletions

View File

@ -1,4 +1,4 @@
package port
package SSH_PORT
import (
"fmt"
@ -45,10 +45,10 @@ func (pm *PortManager) AddPortRange(startPort, endPort uint16) error {
if startPort > endPort {
return fmt.Errorf("start port cannot be greater than end port")
}
for port := startPort; port <= endPort; port++ {
if _, exists := pm.ports[port]; !exists {
pm.ports[port] = false
pm.sortedPorts = append(pm.sortedPorts, port)
for SSH_PORT := startPort; SSH_PORT <= endPort; SSH_PORT++ {
if _, exists := pm.ports[SSH_PORT]; !exists {
pm.ports[SSH_PORT] = false
pm.sortedPorts = append(pm.sortedPorts, SSH_PORT)
}
}
sort.Slice(pm.sortedPorts, func(i, j int) bool {
@ -61,30 +61,30 @@ func (pm *PortManager) GetUnassignedPort() (uint16, bool) {
pm.mu.Lock()
defer pm.mu.Unlock()
for _, port := range pm.sortedPorts {
if !pm.ports[port] {
pm.ports[port] = true
return port, true
for _, SSH_PORT := range pm.sortedPorts {
if !pm.ports[SSH_PORT] {
pm.ports[SSH_PORT] = true
return SSH_PORT, true
}
}
return 0, false
}
func (pm *PortManager) SetPortStatus(port uint16, assigned bool) error {
func (pm *PortManager) SetPortStatus(SSH_PORT uint16, assigned bool) error {
pm.mu.Lock()
defer pm.mu.Unlock()
if _, exists := pm.ports[port]; !exists {
return fmt.Errorf("port %d is not in the allowed range", port)
if _, exists := pm.ports[SSH_PORT]; !exists {
return fmt.Errorf("port %d is not in the allowed range", SSH_PORT)
}
pm.ports[port] = assigned
pm.ports[SSH_PORT] = assigned
return nil
}
func (pm *PortManager) GetPortStatus(port uint16) (bool, bool) {
func (pm *PortManager) GetPortStatus(SSH_PORT uint16) (bool, bool) {
pm.mu.RLock()
defer pm.mu.RUnlock()
status, exists := pm.ports[port]
status, exists := pm.ports[SSH_PORT]
return status, exists
}

View File

@ -1,11 +1,12 @@
package main
import (
"golang.org/x/crypto/ssh"
"log"
"os"
"tunnel_pls/server"
"tunnel_pls/utils"
"golang.org/x/crypto/ssh"
)
func main() {
@ -20,7 +21,7 @@ func main() {
log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile)
privateBytes, err := os.ReadFile(utils.Getenv("ssh_private_key"))
privateBytes, err := os.ReadFile(utils.Getenv("SSH_PRIVATE_KEY"))
if err != nil {
log.Fatalf("Failed to load private key : %s", err.Error())
}

View File

@ -63,7 +63,7 @@ var allowedCors = make(map[string]bool)
var isAllowedAllCors = false
func init() {
corsList := utils.Getenv("cors_list")
corsList := utils.Getenv("CORS_LIST")
if corsList == "*" {
isAllowedAllCors = true
} else {
@ -86,11 +86,11 @@ func NewHTTPServer() error {
}
}
listener, err := net.Listen("tcp", ":80")
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("HTTP_PORT")))
if err != nil {
return errors.New("Error listening: " + err.Error())
}
if utils.Getenv("tls_enabled") == "true" && utils.Getenv("tls_redirect") == "true" {
if utils.Getenv("TLS_ENABLED") == "true" && utils.Getenv("TLS_ENABLED") == "true" {
redirectTLS = true
}
go func() {
@ -129,7 +129,7 @@ func Handler(conn net.Conn) {
if redirectTLS {
conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, utils.Getenv("domain")) +
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, utils.Getenv("DOMAIN")) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))

View File

@ -16,7 +16,7 @@ import (
)
func NewHTTPSServer() error {
cert, err := tls.LoadX509KeyPair(utils.Getenv("cert_loc"), utils.Getenv("key_loc"))
cert, err := tls.LoadX509KeyPair(utils.Getenv("CERT_LOC"), utils.Getenv("KEY_LOC"))
if err != nil {
return err
}

View File

@ -2,11 +2,12 @@ package server
import (
"fmt"
"golang.org/x/crypto/ssh"
"log"
"net"
"net/http"
"tunnel_pls/utils"
"golang.org/x/crypto/ssh"
)
type Server struct {
@ -16,12 +17,12 @@ type Server struct {
}
func NewServer(config ssh.ServerConfig) *Server {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("port")))
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("SSH_PORT")))
if err != nil {
log.Fatalf("failed to listen on port 2200: %v", err)
return nil
}
if utils.Getenv("tls_enabled") == "true" {
if utils.Getenv("TLS_ENABLED") == "true" {
go func() {
err := NewHTTPSServer()
if err != nil {

View File

@ -92,7 +92,6 @@ func (s *Session) Close() error {
if s.Listener != nil {
err := s.Listener.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
fmt.Println("1")
return err
}
}
@ -100,7 +99,6 @@ func (s *Session) Close() error {
if s.ConnChannel != nil {
err := s.ConnChannel.Close()
if err != nil && !errors.Is(err, io.EOF) {
fmt.Println("2")
return err
}
}
@ -108,8 +106,6 @@ func (s *Session) Close() error {
if s.Connection != nil {
err := s.Connection.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
fmt.Println("3")
return err
}
}
@ -121,7 +117,6 @@ func (s *Session) Close() error {
if s.TunnelType == TCP {
err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
if err != nil {
fmt.Println("4")
return err
}
}
@ -131,16 +126,35 @@ func (s *Session) Close() error {
}
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest {
switch req.Type {
case "tcpip-forward":
s.handleTCPIPForward(req)
return
case "shell", "pty-req", "window-change":
req.Reply(true, nil)
default:
log.Println("Unknown request type:", req.Type)
req.Reply(false, nil)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
for {
select {
case req, ok := <-GlobalRequest:
if !ok || req == nil {
log.Println("GlobalRequest channel closed")
return
}
switch req.Type {
case "tcpip-forward":
cancel()
s.handleTCPIPForward(req)
return
case "shell", "pty-req", "window-change":
req.Reply(true, nil)
default:
log.Println("Unknown request type:", req.Type)
req.Reply(false, nil)
}
case <-ctx.Done():
if s.Status == SETUP {
s.sendMessage("No forwarding request detected. See https://tunnl.live for setup help.\n\r")
err := s.Close()
if err != nil {
log.Println("Cannot close connection: ", err)
return
}
}
}
}
}
@ -216,15 +230,15 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 {
func isBlockedPort(SSH_PORT uint16) bool {
if SSH_PORT == 80 || SSH_PORT == 443 {
return false
}
if port < 1024 && port != 0 {
if SSH_PORT < 1024 && SSH_PORT != 0 {
return true
}
for _, p := range blockedReservedPorts {
if p == port {
if p == SSH_PORT {
return true
}
}
@ -250,13 +264,13 @@ func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint16) {
s.waitForRunningStatus()
domain := utils.Getenv("domain")
DOMAIN := utils.Getenv("DOMAIN")
protocol := "http"
if utils.Getenv("tls_enabled") == "true" {
if utils.Getenv("TLS_ENABLED") == "true" {
protocol = "https"
}
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain))
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, DOMAIN))
req.Reply(true, buf.Bytes())
}
@ -403,13 +417,13 @@ func (s *Session) handleSlugEditMode(connection ssh.Channel, inSlugEditMode *boo
if len(*editSlug) > 0 {
*editSlug = (*editSlug)[:len(*editSlug)-1]
connection.Write([]byte("\r\033[K"))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain")))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN")))
}
} else if char >= 32 && char <= 126 {
if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' {
*editSlug += string(char)
connection.Write([]byte("\r\033[K"))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain")))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN")))
}
}
}
@ -438,7 +452,7 @@ func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, e
}
connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n"))
connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n"))
connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("DOMAIN") + "\r\n\r\n"))
connection.Write([]byte("Press any key to continue...\r\n"))
} else if isForbiddenSlug(*editSlug) {
connection.Write([]byte("\r\n\r\n❌ FORBIDDEN SUBDOMAIN ❌\r\n\r\n"))
@ -457,12 +471,12 @@ func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, e
connection.Write([]byte("\033[H\033[2J"))
showWelcomeMessage(connection)
domain := utils.Getenv("domain")
DOMAIN := utils.Getenv("DOMAIN")
protocol := "http"
if utils.Getenv("tls_enabled") == "true" {
if utils.Getenv("TLS_ENABLED") == "true" {
protocol = "https"
}
connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain)))
connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, DOMAIN)))
*inSlugEditMode = false
commandBuffer.Reset()
@ -534,15 +548,15 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
case "/clear":
connection.Write([]byte("\033[H\033[2J"))
showWelcomeMessage(s.ConnChannel)
domain := utils.Getenv("domain")
DOMAIN := utils.Getenv("DOMAIN")
if s.TunnelType == HTTP {
protocol := "http"
if utils.Getenv("tls_enabled") == "true" {
if utils.Getenv("TLS_ENABLED") == "true" {
protocol = "https"
}
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain))
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, DOMAIN))
} else {
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, domain, s.ForwardedPort))
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, DOMAIN, s.ForwardedPort))
}
case "/slug":
@ -553,7 +567,7 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
*editSlug = s.Slug
connection.Write([]byte("\033[H\033[2J"))
displaySlugEditor(connection, s.Slug)
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain")))
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN")))
}
default:
connection.Write([]byte("\r\nUnknown command"))
@ -684,8 +698,8 @@ func showWelcomeMessage(connection ssh.Channel) {
}
func displaySlugEditor(connection ssh.Channel, currentSlug string) {
domain := utils.Getenv("domain")
fullDomain := currentSlug + "." + domain
DOMAIN := utils.Getenv("DOMAIN")
fullDomain := currentSlug + "." + DOMAIN
const paddingRight = 4
@ -742,15 +756,15 @@ func ParseAddr(addr string) (string, uint32) {
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
return "0.0.0.0", uint32(0)
}
port, _ := strconv.Atoi(portStr)
return host, uint32(port)
SSH_PORT, _ := strconv.Atoi(portStr)
return host, uint32(SSH_PORT)
}
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
func createForwardedTCPIPPayload(host string, originPort, SSH_PORT uint16) []byte {
var buf bytes.Buffer
writeSSHString(&buf, "localhost")
binary.Write(&buf, binary.BigEndian, uint32(port))
binary.Write(&buf, binary.BigEndian, uint32(SSH_PORT))
writeSSHString(&buf, host)
binary.Write(&buf, binary.BigEndian, uint32(originPort))

View File

@ -1,9 +1,10 @@
package session
import (
"golang.org/x/crypto/ssh"
"net"
"sync"
"golang.org/x/crypto/ssh"
)
type TunnelType string
@ -44,7 +45,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {
ch, reqs, _ := channel.Accept()
if session.ConnChannel == nil {
session.ConnChannel = ch
session.Status = RUNNING
go session.HandleGlobalRequest(forwardingReq)
}
go session.HandleGlobalRequest(reqs)