fix: correct logic when checking tcpip-forward request
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 5m34s

This commit is contained in:
2025-12-26 23:17:13 +07:00
parent 6dff735216
commit 76d1202b8e
7 changed files with 130 additions and 89 deletions

View File

@@ -9,6 +9,7 @@ import (
"net"
"strconv"
"sync"
"time"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"tunnel_pls/utils"
@@ -70,26 +71,52 @@ func (f *Forwarder) AcceptTCPConnections() {
log.Printf("Error accepting connection: %v", err)
continue
}
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Printf("Failed to set connection deadline: %v", err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
for req := range reqs {
err := req.Reply(false, nil)
if err != nil {
log.Printf("Failed to reply to request: %v", err)
return
}
}
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
go f.HandleConnection(conn, channel, conn.RemoteAddr())
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
if err := conn.SetDeadline(time.Time{}); err != nil {
log.Printf("Failed to clear connection deadline: %v", err)
}
go ssh.DiscardRequests(result.reqs)
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
}
}
}

View File

@@ -19,9 +19,6 @@ var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 54
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest {
switch req.Type {
case "tcpip-forward":
s.HandleTCPIPForward(req)
return
case "shell", "pty-req", "window-change":
err := req.Reply(true, nil)
if err != nil {

View File

@@ -2,11 +2,8 @@ package lifecycle
import (
"errors"
"fmt"
"io"
"log"
"net"
"time"
portUtil "tunnel_pls/internal/port"
"tunnel_pls/session/slug"
"tunnel_pls/types"
@@ -41,7 +38,6 @@ func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
type SessionLifecycle interface {
Close() error
WaitForRunningStatus()
SetStatus(status types.Status)
GetConnection() ssh.Conn
GetChannel() ssh.Channel
@@ -62,33 +58,6 @@ func (l *Lifecycle) GetConnection() ssh.Conn {
func (l *Lifecycle) SetStatus(status types.Status) {
l.Status = status
}
func (l *Lifecycle) WaitForRunningStatus() {
timeout := time.After(3 * time.Second)
ticker := time.NewTicker(150 * time.Millisecond)
defer ticker.Stop()
frames := []string{"-", "\\", "|", "/"}
i := 0
for {
select {
case <-ticker.C:
l.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i]))
i = (i + 1) % len(frames)
if l.Status == types.RUNNING {
l.Interaction.SendMessage("\r\033[K")
return
}
case <-timeout:
l.Interaction.SendMessage("\r\033[K")
l.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n")
err := l.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
log.Println("Timeout waiting for session to start running")
return
}
}
}
func (l *Lifecycle) Close() error {
err := l.Forwarder.Close()

View File

@@ -2,13 +2,15 @@ package session
import (
"bytes"
"fmt"
"log"
"sync"
"time"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"tunnel_pls/utils"
"golang.org/x/crypto/ssh"
)
@@ -30,8 +32,6 @@ type SSHSession struct {
Interaction interaction.Controller
Forwarder forwarder.ForwardingController
SlugManager slug.Manager
channelOnce sync.Once
}
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
@@ -71,20 +71,27 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
SlugManager: slugManager,
}
var once sync.Once
for channel := range sshChan {
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
continue
}
session.channelOnce.Do(func() {
once.Do(func() {
session.Lifecycle.SetChannel(ch)
session.Interaction.SetChannel(ch)
session.Lifecycle.SetStatus(types.SETUP)
go session.HandleGlobalRequest(forwardingReq)
session.Lifecycle.WaitForRunningStatus()
})
tcpipReq := session.waitForTCPIPForward(forwardingReq)
if tcpipReq == nil {
session.Interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200")))
if err := session.Lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
session.HandleTCPIPForward(tcpipReq)
})
go session.HandleGlobalRequest(reqs)
}
if err := session.Lifecycle.Close(); err != nil {
@@ -92,6 +99,27 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
}
}
func (s *SSHSession) waitForTCPIPForward(forwardingReq <-chan *ssh.Request) *ssh.Request {
select {
case req, ok := <-forwardingReq:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
if req.Type == "tcpip-forward" {
return req
}
if err := req.Reply(false, nil); err != nil {
log.Printf("Failed to reply to request: %v", err)
}
log.Printf("Expected tcpip-forward request, got: %s", req.Type)
return nil
case <-time.After(500 * time.Millisecond):
log.Println("No forwarding request received")
return nil
}
}
func updateClientSlug(oldSlug, newSlug string) bool {
clientsMutex.Lock()
defer clientsMutex.Unlock()