fix: correct logic when checking tcpip-forward request
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 5m34s

This commit is contained in:
2025-12-26 23:17:13 +07:00
parent 6dff735216
commit 76d1202b8e
7 changed files with 130 additions and 89 deletions

View File

@@ -10,8 +10,11 @@ import (
"net"
"regexp"
"strings"
"time"
"tunnel_pls/session"
"tunnel_pls/utils"
"golang.org/x/crypto/ssh"
)
type Interaction interface {
@@ -295,30 +298,38 @@ func Handler(conn net.Conn) {
func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) {
payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr)
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
if closer, ok := cw.writer.(io.Closer); ok {
if closeErr := closer.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
var channel ssh.Channel
var reqs <-chan *ssh.Request
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer)
return
}
channel = result.channel
reqs = result.reqs
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer)
return
}
go func() {
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in request handler goroutine: %v", r)
}
}()
for req := range reqs {
if err := req.Reply(false, nil); err != nil {
log.Printf("Failed to reply to request: %v", err)
return
}
}
}()
go ssh.DiscardRequests(reqs)
fingerprintMiddleware := NewTunnelFingerprint()
forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr)
@@ -329,14 +340,13 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
cw.reqHeader = initialRequest
for _, m := range cw.reqStartMW {
err = m.HandleRequest(cw.reqHeader)
if err != nil {
if err := m.HandleRequest(cw.reqHeader); err != nil {
log.Printf("Error handling request: %v", err)
return
}
}
_, err = channel.Write(initialRequest.Finalize())
_, err := channel.Write(initialRequest.Finalize())
if err != nil {
log.Printf("Failed to forward request: %v", err)
return

View File

@@ -23,20 +23,15 @@ func NewServer(config *ssh.ServerConfig) *Server {
return nil
}
if utils.Getenv("TLS_ENABLED", "false") == "true" {
go func() {
err = NewHTTPSServer()
if err != nil {
log.Fatalf("failed to start https server: %v", err)
}
return
}()
}
go func() {
err = NewHTTPServer()
err = NewHTTPSServer()
if err != nil {
log.Fatalf("failed to start http server: %v", err)
log.Fatalf("failed to start https server: %v", err)
}
}()
}
err = NewHTTPServer()
if err != nil {
log.Fatalf("failed to start http server: %v", err)
}
return &Server{
Conn: &listener,
Config: config,