Improve concurrency and resource management #2
@ -5,17 +5,14 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"golang.org/x/net/context"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
"tunnel_pls/session"
|
"tunnel_pls/session"
|
||||||
"tunnel_pls/utils"
|
"tunnel_pls/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var redirectTLS bool = false
|
var redirectTLS = false
|
||||||
|
|
||||||
func NewHTTPServer() error {
|
func NewHTTPServer() error {
|
||||||
listener, err := net.Listen("tcp", ":80")
|
listener, err := net.Listen("tcp", ":80")
|
||||||
@ -81,23 +78,10 @@ func Handler(conn net.Conn) {
|
|||||||
conn.Close()
|
conn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
keepalive, timeout := parseConnectionDetails(headers)
|
|
||||||
var ctx context.Context
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
if keepalive {
|
|
||||||
if timeout >= 300 {
|
|
||||||
timeout = 300
|
|
||||||
}
|
|
||||||
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second))
|
|
||||||
} else {
|
|
||||||
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
|
||||||
}
|
|
||||||
|
|
||||||
sshSession.HandleForwardedConnection(session.UserConnection{
|
sshSession.HandleForwardedConnection(session.UserConnection{
|
||||||
Reader: reader,
|
Reader: reader,
|
||||||
Writer: conn,
|
Writer: conn,
|
||||||
Context: ctx,
|
|
||||||
Cancel: cancel,
|
|
||||||
}, sshSession.Connection)
|
}, sshSession.Connection)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -131,42 +115,3 @@ func parseHostFromHeader(data []byte) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseConnectionDetails(data []byte) (keepAlive bool, timeout int) {
|
|
||||||
keepAlive = false
|
|
||||||
timeout = 30
|
|
||||||
|
|
||||||
lines := strings.Split(string(data), "\r\n")
|
|
||||||
|
|
||||||
for _, line := range lines {
|
|
||||||
if strings.HasPrefix(strings.ToLower(line), "connection:") {
|
|
||||||
value := strings.TrimSpace(strings.TrimPrefix(strings.ToLower(line), "connection:"))
|
|
||||||
keepAlive = (value == "keep-alive")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if keepAlive {
|
|
||||||
for _, line := range lines {
|
|
||||||
if strings.HasPrefix(strings.ToLower(line), "keep-alive:") {
|
|
||||||
value := strings.TrimSpace(strings.TrimPrefix(line, "Keep-Alive:"))
|
|
||||||
|
|
||||||
if strings.Contains(value, "timeout=") {
|
|
||||||
parts := strings.Split(value, ",")
|
|
||||||
for _, part := range parts {
|
|
||||||
part = strings.TrimSpace(part)
|
|
||||||
if strings.HasPrefix(part, "timeout=") {
|
|
||||||
timeoutStr := strings.TrimPrefix(part, "timeout=")
|
|
||||||
if t, err := strconv.Atoi(timeoutStr); err == nil {
|
|
||||||
timeout = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return keepAlive, timeout
|
|
||||||
}
|
|
||||||
|
|||||||
@ -4,11 +4,9 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"golang.org/x/net/context"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
"tunnel_pls/session"
|
"tunnel_pls/session"
|
||||||
"tunnel_pls/utils"
|
"tunnel_pls/utils"
|
||||||
)
|
)
|
||||||
@ -70,23 +68,10 @@ func HandlerTLS(conn net.Conn) {
|
|||||||
conn.Close()
|
conn.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
keepalive, timeout := parseConnectionDetails(headers)
|
|
||||||
var ctx context.Context
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
if keepalive {
|
|
||||||
if timeout >= 300 {
|
|
||||||
timeout = 300
|
|
||||||
}
|
|
||||||
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second))
|
|
||||||
} else {
|
|
||||||
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
|
||||||
}
|
|
||||||
|
|
||||||
sshSession.HandleForwardedConnection(session.UserConnection{
|
sshSession.HandleForwardedConnection(session.UserConnection{
|
||||||
Reader: reader,
|
Reader: reader,
|
||||||
Writer: conn,
|
Writer: conn,
|
||||||
Context: ctx,
|
|
||||||
Cancel: cancel,
|
|
||||||
}, sshSession.Connection)
|
}, sshSession.Connection)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package session
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -16,7 +17,6 @@ import (
|
|||||||
portUtil "tunnel_pls/internal/port"
|
portUtil "tunnel_pls/internal/port"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/net/context"
|
|
||||||
"tunnel_pls/utils"
|
"tunnel_pls/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,8 +31,6 @@ const (
|
|||||||
type UserConnection struct {
|
type UserConnection struct {
|
||||||
Reader io.Reader
|
Reader io.Reader
|
||||||
Writer net.Conn
|
Writer net.Conn
|
||||||
Context context.Context
|
|
||||||
Cancel context.CancelFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -78,6 +76,13 @@ func updateClientSlug(oldSlug, newSlug string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Session) safeClose() {
|
||||||
|
s.once.Do(func() {
|
||||||
|
close(s.ChannelChan)
|
||||||
|
close(s.Done)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Session) Close() {
|
func (s *Session) Close() {
|
||||||
if s.Listener != nil {
|
if s.Listener != nil {
|
||||||
s.Listener.Close()
|
s.Listener.Close()
|
||||||
@ -99,7 +104,7 @@ func (s *Session) Close() {
|
|||||||
portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
|
portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
close(s.Done)
|
s.safeClose()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||||
@ -269,7 +274,6 @@ func (s *Session) acceptTCPConnections() {
|
|||||||
go s.HandleForwardedConnection(UserConnection{
|
go s.HandleForwardedConnection(UserConnection{
|
||||||
Reader: nil,
|
Reader: nil,
|
||||||
Writer: conn,
|
Writer: conn,
|
||||||
Context: context.Background(),
|
|
||||||
}, s.Connection)
|
}, s.Connection)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -538,41 +542,106 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se
|
|||||||
}
|
}
|
||||||
defer channel.Close()
|
defer channel.Close()
|
||||||
|
|
||||||
go handleChannelRequests(reqs, conn, channel)
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("Panic in request handler: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for req := range reqs {
|
||||||
|
req.Reply(false, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if conn.Reader == nil {
|
if conn.Reader == nil {
|
||||||
conn.Reader = bufio.NewReader(conn.Writer)
|
conn.Reader = bufio.NewReader(conn.Writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
go io.Copy(channel, conn.Reader)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("Panic in reader copy: %v", r)
|
||||||
|
errChan <- fmt.Errorf("panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := io.Copy(channel, conn.Reader)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
log.Printf("Error copying from conn.Reader to channel: %v", err)
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
reader := bufio.NewReader(channel)
|
reader := bufio.NewReader(channel)
|
||||||
_, err = reader.Peek(1)
|
|
||||||
if err == io.EOF {
|
|
||||||
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, "Could not forward request to the tunnel addr"))
|
|
||||||
sendBadGatewayResponse(conn.Writer)
|
|
||||||
conn.Writer.Close()
|
|
||||||
channel.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.sendMessage(fmt.Sprintf("\033[32m %s -> [%s] TUNNEL ADDRESS -- \"%s\" \r\n \033[0m", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp))
|
peekDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := reader.Peek(1)
|
||||||
|
peekDone <- err
|
||||||
|
}()
|
||||||
|
|
||||||
io.Copy(conn.Writer, reader)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleChannelRequests(reqs <-chan *ssh.Request, conn UserConnection, channel ssh.Channel) {
|
|
||||||
select {
|
select {
|
||||||
case <-reqs:
|
case err := <-peekDone:
|
||||||
for req := range reqs {
|
if err == io.EOF {
|
||||||
req.Reply(false, nil)
|
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
|
||||||
}
|
sendBadGatewayResponse(conn.Writer)
|
||||||
case <-conn.Context.Done():
|
|
||||||
conn.Writer.Close()
|
|
||||||
channel.Close()
|
|
||||||
log.Println("Connection closed by timeout")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error peeking channel data: %v", err)
|
||||||
|
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
|
||||||
|
sendBadGatewayResponse(conn.Writer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
log.Printf("Timeout waiting for channel data from %s", conn.Writer.RemoteAddr())
|
||||||
|
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
|
||||||
|
sendBadGatewayResponse(conn.Writer)
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp))
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("Panic in writer copy: %v", r)
|
||||||
|
errChan <- fmt.Errorf("panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := io.Copy(conn.Writer, reader)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
log.Printf("Error copying from channel to conn.Writer: %v", err)
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(errChan)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for err := range errChan {
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Connection error: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendBadGatewayResponse(writer io.Writer) {
|
func sendBadGatewayResponse(writer io.Writer) {
|
||||||
|
|||||||
@ -3,6 +3,7 @@ package session
|
|||||||
import (
|
import (
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelType string
|
type TunnelType string
|
||||||
@ -24,6 +25,7 @@ type Session struct {
|
|||||||
Slug string
|
Slug string
|
||||||
ChannelChan chan ssh.NewChannel
|
ChannelChan chan ssh.NewChannel
|
||||||
Done chan bool
|
Done chan bool
|
||||||
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {
|
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {
|
||||||
|
|||||||
Reference in New Issue
Block a user