fix: remove timeouts from HTTP/HTTPS handlers and improve concurrency
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 3m40s
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 3m40s
This commit is contained in:
@ -3,6 +3,7 @@ package session
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -16,7 +17,6 @@ import (
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/net/context"
|
||||
"tunnel_pls/utils"
|
||||
)
|
||||
|
||||
@ -29,10 +29,8 @@ const (
|
||||
)
|
||||
|
||||
type UserConnection struct {
|
||||
Reader io.Reader
|
||||
Writer net.Conn
|
||||
Context context.Context
|
||||
Cancel context.CancelFunc
|
||||
Reader io.Reader
|
||||
Writer net.Conn
|
||||
}
|
||||
|
||||
var (
|
||||
@ -78,6 +76,13 @@ func updateClientSlug(oldSlug, newSlug string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Session) safeClose() {
|
||||
s.once.Do(func() {
|
||||
close(s.ChannelChan)
|
||||
close(s.Done)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) Close() {
|
||||
if s.Listener != nil {
|
||||
s.Listener.Close()
|
||||
@ -99,7 +104,7 @@ func (s *Session) Close() {
|
||||
portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
|
||||
}
|
||||
|
||||
close(s.Done)
|
||||
s.safeClose()
|
||||
}
|
||||
|
||||
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||
@ -267,9 +272,8 @@ func (s *Session) acceptTCPConnections() {
|
||||
}
|
||||
|
||||
go s.HandleForwardedConnection(UserConnection{
|
||||
Reader: nil,
|
||||
Writer: conn,
|
||||
Context: context.Background(),
|
||||
Reader: nil,
|
||||
Writer: conn,
|
||||
}, s.Connection)
|
||||
}
|
||||
}
|
||||
@ -538,40 +542,105 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se
|
||||
}
|
||||
defer channel.Close()
|
||||
|
||||
go handleChannelRequests(reqs, conn, channel)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("Panic in request handler: %v", r)
|
||||
}
|
||||
}()
|
||||
for req := range reqs {
|
||||
req.Reply(false, nil)
|
||||
}
|
||||
}()
|
||||
|
||||
if conn.Reader == nil {
|
||||
conn.Reader = bufio.NewReader(conn.Writer)
|
||||
}
|
||||
|
||||
go io.Copy(channel, conn.Reader)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 2)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("Panic in reader copy: %v", r)
|
||||
errChan <- fmt.Errorf("panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := io.Copy(channel, conn.Reader)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error copying from conn.Reader to channel: %v", err)
|
||||
errChan <- err
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
reader := bufio.NewReader(channel)
|
||||
_, err = reader.Peek(1)
|
||||
if err == io.EOF {
|
||||
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, "Could not forward request to the tunnel addr"))
|
||||
|
||||
peekDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := reader.Peek(1)
|
||||
peekDone <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-peekDone:
|
||||
if err == io.EOF {
|
||||
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
|
||||
sendBadGatewayResponse(conn.Writer)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Printf("Error peeking channel data: %v", err)
|
||||
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
|
||||
sendBadGatewayResponse(conn.Writer)
|
||||
return
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("Timeout waiting for channel data from %s", conn.Writer.RemoteAddr())
|
||||
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
|
||||
sendBadGatewayResponse(conn.Writer)
|
||||
conn.Writer.Close()
|
||||
channel.Close()
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
s.sendMessage(fmt.Sprintf("\033[32m %s -> [%s] TUNNEL ADDRESS -- \"%s\" \r\n \033[0m", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp))
|
||||
s.sendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp))
|
||||
|
||||
io.Copy(conn.Writer, reader)
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("Panic in writer copy: %v", r)
|
||||
errChan <- fmt.Errorf("panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
func handleChannelRequests(reqs <-chan *ssh.Request, conn UserConnection, channel ssh.Channel) {
|
||||
select {
|
||||
case <-reqs:
|
||||
for req := range reqs {
|
||||
req.Reply(false, nil)
|
||||
_, err := io.Copy(conn.Writer, reader)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error copying from channel to conn.Writer: %v", err)
|
||||
errChan <- err
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
}()
|
||||
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
log.Printf("Connection error: %v", err)
|
||||
break
|
||||
}
|
||||
case <-conn.Context.Done():
|
||||
conn.Writer.Close()
|
||||
channel.Close()
|
||||
log.Println("Connection closed by timeout")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ package session
|
||||
import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type TunnelType string
|
||||
@ -24,6 +25,7 @@ type Session struct {
|
||||
Slug string
|
||||
ChannelChan chan ssh.NewChannel
|
||||
Done chan bool
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {
|
||||
|
||||
Reference in New Issue
Block a user