refactor: optimize HTTP header parsing
This commit is contained in:
@ -8,6 +8,7 @@ import (
|
||||
"golang.org/x/net/context"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"tunnel_pls/session"
|
||||
@ -42,10 +43,8 @@ func NewHTTPServer() error {
|
||||
}
|
||||
|
||||
func Handler(conn net.Conn) {
|
||||
//TODO: Determain deadline time/set custom timeout on env
|
||||
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
|
||||
reader := bufio.NewReader(conn)
|
||||
headers, err := peekUntilHeaders(reader, 512)
|
||||
headers, err := peekUntilHeaders(reader, 8192)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to peek headers:", err)
|
||||
return
|
||||
@ -61,7 +60,6 @@ func Handler(conn net.Conn) {
|
||||
|
||||
if len(host) < 1 {
|
||||
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||
fmt.Println("Bad Request")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
@ -80,16 +78,27 @@ func Handler(conn net.Conn) {
|
||||
sshSession, ok := session.Clients[slug]
|
||||
if !ok {
|
||||
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||
fmt.Println("Bad Request 1")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
keepalive, timeout := parseConnectionDetails(headers)
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
if keepalive {
|
||||
if timeout >= 300 {
|
||||
timeout = 300
|
||||
}
|
||||
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second))
|
||||
} else {
|
||||
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
||||
}
|
||||
|
||||
sshSession.HandleForwardedConnection(session.UserConnection{
|
||||
Reader: reader,
|
||||
Writer: conn,
|
||||
Context: ctx,
|
||||
}, sshSession.Connection, 80)
|
||||
Cancel: cancel,
|
||||
}, sshSession.Connection)
|
||||
return
|
||||
}
|
||||
|
||||
@ -122,3 +131,42 @@ func parseHostFromHeader(data []byte) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseConnectionDetails(data []byte) (keepAlive bool, timeout int) {
|
||||
keepAlive = false
|
||||
timeout = 30
|
||||
|
||||
lines := strings.Split(string(data), "\r\n")
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(strings.ToLower(line), "connection:") {
|
||||
value := strings.TrimSpace(strings.TrimPrefix(strings.ToLower(line), "connection:"))
|
||||
keepAlive = (value == "keep-alive")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if keepAlive {
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(strings.ToLower(line), "keep-alive:") {
|
||||
value := strings.TrimSpace(strings.TrimPrefix(line, "Keep-Alive:"))
|
||||
|
||||
if strings.Contains(value, "timeout=") {
|
||||
parts := strings.Split(value, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "timeout=") {
|
||||
timeoutStr := strings.TrimPrefix(part, "timeout=")
|
||||
if t, err := strconv.Atoi(timeoutStr); err == nil {
|
||||
timeout = t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return keepAlive, timeout
|
||||
}
|
||||
|
||||
@ -43,9 +43,8 @@ func NewHTTPSServer() error {
|
||||
}
|
||||
|
||||
func HandlerTLS(conn net.Conn) {
|
||||
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
|
||||
reader := bufio.NewReader(conn)
|
||||
headers, err := peekUntilHeaders(reader, 512)
|
||||
headers, err := peekUntilHeaders(reader, 8192)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to peek headers:", err)
|
||||
return
|
||||
@ -54,14 +53,12 @@ func HandlerTLS(conn net.Conn) {
|
||||
host := strings.Split(parseHostFromHeader(headers), ".")
|
||||
if len(host) < 1 {
|
||||
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||
fmt.Println("Bad Request")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if len(host) < 1 {
|
||||
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||
fmt.Println("Bad Request")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
@ -70,15 +67,26 @@ func HandlerTLS(conn net.Conn) {
|
||||
sshSession, ok := session.Clients[slug]
|
||||
if !ok {
|
||||
conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||
fmt.Println("Bad Request 1")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
keepalive, timeout := parseConnectionDetails(headers)
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
if keepalive {
|
||||
if timeout >= 300 {
|
||||
timeout = 300
|
||||
}
|
||||
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second))
|
||||
} else {
|
||||
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
||||
}
|
||||
|
||||
sshSession.HandleForwardedConnection(session.UserConnection{
|
||||
Reader: reader,
|
||||
Writer: conn,
|
||||
Context: ctx,
|
||||
}, sshSession.Connection, 80)
|
||||
Cancel: cancel,
|
||||
}, sshSession.Connection)
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user