Improve concurrency and resource management #2

Merged
bagas merged 12 commits from staging into main 2025-07-23 06:51:09 +00:00
7 changed files with 250 additions and 204 deletions
Showing only changes of commit 4b21541668 - Show all commits

View File

@ -560,37 +560,31 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, 2)
wg.Add(1)
go func() { go func() {
defer wg.Done()
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Printf("Panic in reader copy: %v", r) log.Printf("Panic in reader copy: %v", r)
errChan <- fmt.Errorf("panic: %v", r)
} }
cancel()
}() }()
_, err := io.Copy(channel, conn.Reader) _, err := io.Copy(channel, conn.Reader)
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error copying from conn.Reader to channel: %v", err) log.Printf("Error copying from conn.Reader to channel: %v", err)
errChan <- err
} }
cancel() cancel()
}() }()
reader := bufio.NewReader(channel) reader := bufio.NewReader(channel)
peekDone := make(chan error, 1) peekChan := make(chan error, 1)
go func() { go func() {
_, err := reader.Peek(1) _, err := reader.Peek(1)
peekDone <- err peekChan <- err
}() }()
select { select {
case err := <-peekDone: case err := <-peekChan:
if err == io.EOF { if err == io.EOF {
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
sendBadGatewayResponse(conn.Writer) sendBadGatewayResponse(conn.Writer)
@ -613,34 +607,9 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se
s.sendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) s.sendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp))
wg.Add(1) _, err = io.Copy(conn.Writer, reader)
go func() { if err != nil && !errors.Is(err, io.EOF) {
defer wg.Done() log.Printf("Error copying from channel to conn.Writer: %v", err)
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in writer copy: %v", r)
errChan <- fmt.Errorf("panic: %v", r)
}
}()
_, err := io.Copy(conn.Writer, reader)
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error copying from channel to conn.Writer: %v", err)
errChan <- err
}
cancel()
}()
go func() {
wg.Wait()
close(errChan)
}()
for err := range errChan {
if err != nil {
log.Printf("Connection error: %v", err)
break
}
} }
} }