feat: close connection if no tunneling request is specified
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 7m18s
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 7m18s
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
package port
|
||||
package SSH_PORT
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -45,10 +45,10 @@ func (pm *PortManager) AddPortRange(startPort, endPort uint16) error {
|
||||
if startPort > endPort {
|
||||
return fmt.Errorf("start port cannot be greater than end port")
|
||||
}
|
||||
for port := startPort; port <= endPort; port++ {
|
||||
if _, exists := pm.ports[port]; !exists {
|
||||
pm.ports[port] = false
|
||||
pm.sortedPorts = append(pm.sortedPorts, port)
|
||||
for SSH_PORT := startPort; SSH_PORT <= endPort; SSH_PORT++ {
|
||||
if _, exists := pm.ports[SSH_PORT]; !exists {
|
||||
pm.ports[SSH_PORT] = false
|
||||
pm.sortedPorts = append(pm.sortedPorts, SSH_PORT)
|
||||
}
|
||||
}
|
||||
sort.Slice(pm.sortedPorts, func(i, j int) bool {
|
||||
@ -61,30 +61,30 @@ func (pm *PortManager) GetUnassignedPort() (uint16, bool) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
for _, port := range pm.sortedPorts {
|
||||
if !pm.ports[port] {
|
||||
pm.ports[port] = true
|
||||
return port, true
|
||||
for _, SSH_PORT := range pm.sortedPorts {
|
||||
if !pm.ports[SSH_PORT] {
|
||||
pm.ports[SSH_PORT] = true
|
||||
return SSH_PORT, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (pm *PortManager) SetPortStatus(port uint16, assigned bool) error {
|
||||
func (pm *PortManager) SetPortStatus(SSH_PORT uint16, assigned bool) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if _, exists := pm.ports[port]; !exists {
|
||||
return fmt.Errorf("port %d is not in the allowed range", port)
|
||||
if _, exists := pm.ports[SSH_PORT]; !exists {
|
||||
return fmt.Errorf("port %d is not in the allowed range", SSH_PORT)
|
||||
}
|
||||
pm.ports[port] = assigned
|
||||
pm.ports[SSH_PORT] = assigned
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *PortManager) GetPortStatus(port uint16) (bool, bool) {
|
||||
func (pm *PortManager) GetPortStatus(SSH_PORT uint16) (bool, bool) {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
|
||||
status, exists := pm.ports[port]
|
||||
status, exists := pm.ports[SSH_PORT]
|
||||
return status, exists
|
||||
}
|
||||
|
||||
5
main.go
5
main.go
@ -1,11 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
"log"
|
||||
"os"
|
||||
"tunnel_pls/server"
|
||||
"tunnel_pls/utils"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -20,7 +21,7 @@ func main() {
|
||||
log.SetOutput(os.Stdout)
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
privateBytes, err := os.ReadFile(utils.Getenv("ssh_private_key"))
|
||||
privateBytes, err := os.ReadFile(utils.Getenv("SSH_PRIVATE_KEY"))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load private key : %s", err.Error())
|
||||
}
|
||||
|
||||
@ -63,7 +63,7 @@ var allowedCors = make(map[string]bool)
|
||||
var isAllowedAllCors = false
|
||||
|
||||
func init() {
|
||||
corsList := utils.Getenv("cors_list")
|
||||
corsList := utils.Getenv("CORS_LIST")
|
||||
if corsList == "*" {
|
||||
isAllowedAllCors = true
|
||||
} else {
|
||||
@ -86,11 +86,11 @@ func NewHTTPServer() error {
|
||||
}
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", ":80")
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("HTTP_PORT")))
|
||||
if err != nil {
|
||||
return errors.New("Error listening: " + err.Error())
|
||||
}
|
||||
if utils.Getenv("tls_enabled") == "true" && utils.Getenv("tls_redirect") == "true" {
|
||||
if utils.Getenv("TLS_ENABLED") == "true" && utils.Getenv("TLS_ENABLED") == "true" {
|
||||
redirectTLS = true
|
||||
}
|
||||
go func() {
|
||||
@ -129,7 +129,7 @@ func Handler(conn net.Conn) {
|
||||
|
||||
if redirectTLS {
|
||||
conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
|
||||
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, utils.Getenv("domain")) +
|
||||
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, utils.Getenv("DOMAIN")) +
|
||||
"Content-Length: 0\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"\r\n"))
|
||||
|
||||
@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
func NewHTTPSServer() error {
|
||||
cert, err := tls.LoadX509KeyPair(utils.Getenv("cert_loc"), utils.Getenv("key_loc"))
|
||||
cert, err := tls.LoadX509KeyPair(utils.Getenv("CERT_LOC"), utils.Getenv("KEY_LOC"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -2,11 +2,12 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"tunnel_pls/utils"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
@ -16,12 +17,12 @@ type Server struct {
|
||||
}
|
||||
|
||||
func NewServer(config ssh.ServerConfig) *Server {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("port")))
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("SSH_PORT")))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen on port 2200: %v", err)
|
||||
return nil
|
||||
}
|
||||
if utils.Getenv("tls_enabled") == "true" {
|
||||
if utils.Getenv("TLS_ENABLED") == "true" {
|
||||
go func() {
|
||||
err := NewHTTPSServer()
|
||||
if err != nil {
|
||||
|
||||
@ -92,7 +92,6 @@ func (s *Session) Close() error {
|
||||
if s.Listener != nil {
|
||||
err := s.Listener.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
fmt.Println("1")
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -100,7 +99,6 @@ func (s *Session) Close() error {
|
||||
if s.ConnChannel != nil {
|
||||
err := s.ConnChannel.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
fmt.Println("2")
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -108,8 +106,6 @@ func (s *Session) Close() error {
|
||||
if s.Connection != nil {
|
||||
err := s.Connection.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
fmt.Println("3")
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -121,7 +117,6 @@ func (s *Session) Close() error {
|
||||
if s.TunnelType == TCP {
|
||||
err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
|
||||
if err != nil {
|
||||
fmt.Println("4")
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -131,9 +126,18 @@ func (s *Session) Close() error {
|
||||
}
|
||||
|
||||
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||
for req := range GlobalRequest {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
for {
|
||||
select {
|
||||
case req, ok := <-GlobalRequest:
|
||||
if !ok || req == nil {
|
||||
log.Println("GlobalRequest channel closed")
|
||||
return
|
||||
}
|
||||
switch req.Type {
|
||||
case "tcpip-forward":
|
||||
cancel()
|
||||
s.handleTCPIPForward(req)
|
||||
return
|
||||
case "shell", "pty-req", "window-change":
|
||||
@ -142,6 +146,16 @@ func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||
log.Println("Unknown request type:", req.Type)
|
||||
req.Reply(false, nil)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
if s.Status == SETUP {
|
||||
s.sendMessage("No forwarding request detected. See https://tunnl.live for setup help.\n\r")
|
||||
err := s.Close()
|
||||
if err != nil {
|
||||
log.Println("Cannot close connection: ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -216,15 +230,15 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) {
|
||||
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func isBlockedPort(port uint16) bool {
|
||||
if port == 80 || port == 443 {
|
||||
func isBlockedPort(SSH_PORT uint16) bool {
|
||||
if SSH_PORT == 80 || SSH_PORT == 443 {
|
||||
return false
|
||||
}
|
||||
if port < 1024 && port != 0 {
|
||||
if SSH_PORT < 1024 && SSH_PORT != 0 {
|
||||
return true
|
||||
}
|
||||
for _, p := range blockedReservedPorts {
|
||||
if p == port {
|
||||
if p == SSH_PORT {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -250,13 +264,13 @@ func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint16) {
|
||||
|
||||
s.waitForRunningStatus()
|
||||
|
||||
domain := utils.Getenv("domain")
|
||||
DOMAIN := utils.Getenv("DOMAIN")
|
||||
protocol := "http"
|
||||
if utils.Getenv("tls_enabled") == "true" {
|
||||
if utils.Getenv("TLS_ENABLED") == "true" {
|
||||
protocol = "https"
|
||||
}
|
||||
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain))
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, DOMAIN))
|
||||
req.Reply(true, buf.Bytes())
|
||||
}
|
||||
|
||||
@ -403,13 +417,13 @@ func (s *Session) handleSlugEditMode(connection ssh.Channel, inSlugEditMode *boo
|
||||
if len(*editSlug) > 0 {
|
||||
*editSlug = (*editSlug)[:len(*editSlug)-1]
|
||||
connection.Write([]byte("\r\033[K"))
|
||||
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain")))
|
||||
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN")))
|
||||
}
|
||||
} else if char >= 32 && char <= 126 {
|
||||
if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' {
|
||||
*editSlug += string(char)
|
||||
connection.Write([]byte("\r\033[K"))
|
||||
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain")))
|
||||
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN")))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -438,7 +452,7 @@ func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, e
|
||||
}
|
||||
|
||||
connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n"))
|
||||
connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n"))
|
||||
connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("DOMAIN") + "\r\n\r\n"))
|
||||
connection.Write([]byte("Press any key to continue...\r\n"))
|
||||
} else if isForbiddenSlug(*editSlug) {
|
||||
connection.Write([]byte("\r\n\r\n❌ FORBIDDEN SUBDOMAIN ❌\r\n\r\n"))
|
||||
@ -457,12 +471,12 @@ func (s *Session) handleSlugSave(connection ssh.Channel, inSlugEditMode *bool, e
|
||||
connection.Write([]byte("\033[H\033[2J"))
|
||||
showWelcomeMessage(connection)
|
||||
|
||||
domain := utils.Getenv("domain")
|
||||
DOMAIN := utils.Getenv("DOMAIN")
|
||||
protocol := "http"
|
||||
if utils.Getenv("tls_enabled") == "true" {
|
||||
if utils.Getenv("TLS_ENABLED") == "true" {
|
||||
protocol = "https"
|
||||
}
|
||||
connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain)))
|
||||
connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, DOMAIN)))
|
||||
|
||||
*inSlugEditMode = false
|
||||
commandBuffer.Reset()
|
||||
@ -534,15 +548,15 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
|
||||
case "/clear":
|
||||
connection.Write([]byte("\033[H\033[2J"))
|
||||
showWelcomeMessage(s.ConnChannel)
|
||||
domain := utils.Getenv("domain")
|
||||
DOMAIN := utils.Getenv("DOMAIN")
|
||||
if s.TunnelType == HTTP {
|
||||
protocol := "http"
|
||||
if utils.Getenv("tls_enabled") == "true" {
|
||||
if utils.Getenv("TLS_ENABLED") == "true" {
|
||||
protocol = "https"
|
||||
}
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, domain))
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, s.Slug, DOMAIN))
|
||||
} else {
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, domain, s.ForwardedPort))
|
||||
s.sendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.TunnelType, DOMAIN, s.ForwardedPort))
|
||||
}
|
||||
|
||||
case "/slug":
|
||||
@ -553,7 +567,7 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd
|
||||
*editSlug = s.Slug
|
||||
connection.Write([]byte("\033[H\033[2J"))
|
||||
displaySlugEditor(connection, s.Slug)
|
||||
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("domain")))
|
||||
connection.Write([]byte("➤ " + *editSlug + "." + utils.Getenv("DOMAIN")))
|
||||
}
|
||||
default:
|
||||
connection.Write([]byte("\r\nUnknown command"))
|
||||
@ -684,8 +698,8 @@ func showWelcomeMessage(connection ssh.Channel) {
|
||||
}
|
||||
|
||||
func displaySlugEditor(connection ssh.Channel, currentSlug string) {
|
||||
domain := utils.Getenv("domain")
|
||||
fullDomain := currentSlug + "." + domain
|
||||
DOMAIN := utils.Getenv("DOMAIN")
|
||||
fullDomain := currentSlug + "." + DOMAIN
|
||||
|
||||
const paddingRight = 4
|
||||
|
||||
@ -742,15 +756,15 @@ func ParseAddr(addr string) (string, uint32) {
|
||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||
return "0.0.0.0", uint32(0)
|
||||
}
|
||||
port, _ := strconv.Atoi(portStr)
|
||||
return host, uint32(port)
|
||||
SSH_PORT, _ := strconv.Atoi(portStr)
|
||||
return host, uint32(SSH_PORT)
|
||||
}
|
||||
|
||||
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
|
||||
func createForwardedTCPIPPayload(host string, originPort, SSH_PORT uint16) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
writeSSHString(&buf, "localhost")
|
||||
binary.Write(&buf, binary.BigEndian, uint32(port))
|
||||
binary.Write(&buf, binary.BigEndian, uint32(SSH_PORT))
|
||||
writeSSHString(&buf, host)
|
||||
binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
||||
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type TunnelType string
|
||||
@ -44,7 +45,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {
|
||||
ch, reqs, _ := channel.Accept()
|
||||
if session.ConnChannel == nil {
|
||||
session.ConnChannel = ch
|
||||
session.Status = RUNNING
|
||||
go session.HandleGlobalRequest(forwardingReq)
|
||||
}
|
||||
go session.HandleGlobalRequest(reqs)
|
||||
|
||||
Reference in New Issue
Block a user