refactor: remove duplicate channel management helpers from HTTP handler
This commit is contained in:
@@ -2,6 +2,7 @@ package transport
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -19,8 +20,6 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var openChannelTimeout = 5 * time.Second
|
||||
|
||||
type httpHandler struct {
|
||||
domain string
|
||||
sessionRegistry registry.Registry
|
||||
@@ -139,13 +138,17 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
||||
}
|
||||
|
||||
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
||||
channel, err := hh.openForwardedChannel(hw, sshSession)
|
||||
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, payload)
|
||||
if err != nil {
|
||||
log.Printf("Failed to establish channel: %v", err)
|
||||
sshSession.Forwarder().WriteBadGatewayResponse(hw)
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
defer func() {
|
||||
err = channel.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
@@ -162,51 +165,6 @@ func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.Requ
|
||||
sshSession.Forwarder().HandleConnection(hw, channel)
|
||||
}
|
||||
|
||||
func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
|
||||
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
|
||||
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
err error
|
||||
}
|
||||
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
hh.cleanupUnusedChannel(channel, reqs)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result.err != nil {
|
||||
return nil, result.err
|
||||
}
|
||||
go ssh.DiscardRequests(result.reqs)
|
||||
return result.channel, nil
|
||||
case <-time.After(openChannelTimeout):
|
||||
go func() {
|
||||
result := <-resultChan
|
||||
hh.cleanupUnusedChannel(result.channel, result.reqs)
|
||||
}()
|
||||
return nil, errors.New("timeout opening forwarded-tcpip channel")
|
||||
}
|
||||
}
|
||||
|
||||
func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
|
||||
if channel != nil {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Printf("Failed to close unused channel: %v", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
}
|
||||
}
|
||||
|
||||
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
|
||||
fingerprintMiddleware := middleware.NewTunnelFingerprint()
|
||||
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
|
||||
|
||||
Reference in New Issue
Block a user