fix: correct logic when checking tcpip-forward request
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 5m34s
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 5m34s
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
"tunnel_pls/utils"
|
||||
@@ -70,26 +71,52 @@ func (f *Forwarder) AcceptTCPConnections() {
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
log.Printf("Failed to set connection deadline: %v", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to reply to request: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
}()
|
||||
go f.HandleConnection(conn, channel, conn.RemoteAddr())
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result.err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||
log.Printf("Failed to clear connection deadline: %v", err)
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(result.reqs)
|
||||
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
|
||||
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("Timeout opening forwarded-tcpip channel")
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,9 +19,6 @@ var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 54
|
||||
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||
for req := range GlobalRequest {
|
||||
switch req.Type {
|
||||
case "tcpip-forward":
|
||||
s.HandleTCPIPForward(req)
|
||||
return
|
||||
case "shell", "pty-req", "window-change":
|
||||
err := req.Reply(true, nil)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,11 +2,8 @@ package lifecycle
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
@@ -41,7 +38,6 @@ func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
|
||||
|
||||
type SessionLifecycle interface {
|
||||
Close() error
|
||||
WaitForRunningStatus()
|
||||
SetStatus(status types.Status)
|
||||
GetConnection() ssh.Conn
|
||||
GetChannel() ssh.Channel
|
||||
@@ -62,33 +58,6 @@ func (l *Lifecycle) GetConnection() ssh.Conn {
|
||||
func (l *Lifecycle) SetStatus(status types.Status) {
|
||||
l.Status = status
|
||||
}
|
||||
func (l *Lifecycle) WaitForRunningStatus() {
|
||||
timeout := time.After(3 * time.Second)
|
||||
ticker := time.NewTicker(150 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
frames := []string{"-", "\\", "|", "/"}
|
||||
i := 0
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
l.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i]))
|
||||
i = (i + 1) % len(frames)
|
||||
if l.Status == types.RUNNING {
|
||||
l.Interaction.SendMessage("\r\033[K")
|
||||
return
|
||||
}
|
||||
case <-timeout:
|
||||
l.Interaction.SendMessage("\r\033[K")
|
||||
l.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n")
|
||||
err := l.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
log.Println("Timeout waiting for session to start running")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Lifecycle) Close() error {
|
||||
err := l.Forwarder.Close()
|
||||
|
||||
@@ -2,13 +2,15 @@ package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
"tunnel_pls/utils"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
@@ -30,8 +32,6 @@ type SSHSession struct {
|
||||
Interaction interaction.Controller
|
||||
Forwarder forwarder.ForwardingController
|
||||
SlugManager slug.Manager
|
||||
|
||||
channelOnce sync.Once
|
||||
}
|
||||
|
||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
|
||||
@@ -71,20 +71,27 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
||||
SlugManager: slugManager,
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
for channel := range sshChan {
|
||||
ch, reqs, err := channel.Accept()
|
||||
if err != nil {
|
||||
log.Printf("failed to accept channel: %v", err)
|
||||
continue
|
||||
}
|
||||
session.channelOnce.Do(func() {
|
||||
once.Do(func() {
|
||||
session.Lifecycle.SetChannel(ch)
|
||||
session.Interaction.SetChannel(ch)
|
||||
session.Lifecycle.SetStatus(types.SETUP)
|
||||
go session.HandleGlobalRequest(forwardingReq)
|
||||
session.Lifecycle.WaitForRunningStatus()
|
||||
})
|
||||
|
||||
tcpipReq := session.waitForTCPIPForward(forwardingReq)
|
||||
if tcpipReq == nil {
|
||||
session.Interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200")))
|
||||
if err := session.Lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
session.HandleTCPIPForward(tcpipReq)
|
||||
})
|
||||
go session.HandleGlobalRequest(reqs)
|
||||
}
|
||||
if err := session.Lifecycle.Close(); err != nil {
|
||||
@@ -92,6 +99,27 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHSession) waitForTCPIPForward(forwardingReq <-chan *ssh.Request) *ssh.Request {
|
||||
select {
|
||||
case req, ok := <-forwardingReq:
|
||||
if !ok {
|
||||
log.Println("Forwarding request channel closed")
|
||||
return nil
|
||||
}
|
||||
if req.Type == "tcpip-forward" {
|
||||
return req
|
||||
}
|
||||
if err := req.Reply(false, nil); err != nil {
|
||||
log.Printf("Failed to reply to request: %v", err)
|
||||
}
|
||||
log.Printf("Expected tcpip-forward request, got: %s", req.Type)
|
||||
return nil
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
log.Println("No forwarding request received")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func updateClientSlug(oldSlug, newSlug string) bool {
|
||||
clientsMutex.Lock()
|
||||
defer clientsMutex.Unlock()
|
||||
|
||||
Reference in New Issue
Block a user