fix: conn reader stuck when header have body
Some checks failed
Docker Build and Push / build-and-push (push) Has been cancelled

This commit is contained in:
2025-12-03 21:14:42 +07:00
parent a3eb08e7ae
commit 515bc30559
6 changed files with 314 additions and 74 deletions

View File

@ -1,14 +1,20 @@
package main package main
import ( import (
"golang.org/x/crypto/ssh"
"log" "log"
"net/http"
_ "net/http/pprof"
"os" "os"
"tunnel_pls/server" "tunnel_pls/server"
"tunnel_pls/utils" "tunnel_pls/utils"
"golang.org/x/crypto/ssh"
) )
func main() { func main() {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
sshConfig := &ssh.ServerConfig{ sshConfig := &ssh.ServerConfig{
NoClientAuth: true, NoClientAuth: true,
ServerVersion: "SSH-2.0-TunnlPls-1.0", ServerVersion: "SSH-2.0-TunnlPls-1.0",

View File

@ -10,7 +10,6 @@ import (
"net" "net"
"regexp" "regexp"
"strings" "strings"
"time"
"tunnel_pls/session" "tunnel_pls/session"
"tunnel_pls/utils" "tunnel_pls/utils"
@ -28,27 +27,66 @@ type CustomWriter struct {
reader io.Reader reader io.Reader
headerBuf []byte headerBuf []byte
buf []byte buf []byte
respHeader *ResponseHeaderFactory
reqHeader *RequestHeaderFactory
interaction *session.Interaction interaction *session.Interaction
respMW []ResponseMiddleware
reqStartMW []RequestMiddleware
reqEndMW []RequestMiddleware
} }
func (cw *CustomWriter) Read(p []byte) (int, error) { func (cw *CustomWriter) Read(p []byte) (int, error) {
if cw == nil { tmp := make([]byte, len(p))
return 0, errors.New("can not read from nil CustomWriter") read, err := cw.reader.Read(tmp)
}
read, err := cw.reader.Read(p)
if err != nil { if err != nil {
return 0, err return 0, err
} }
reader := bytes.NewReader(p)
reqhf, err := NewRequestHeaderFactory(reader) tmp = tmp[:read]
if err != nil {
if errors.Is(err, io.EOF) { idx := bytes.Index(tmp, DELIMITER)
return read, io.EOF if idx == -1 {
copy(p, tmp)
return read, nil
}
header := tmp[:idx+len(DELIMITER)]
body := tmp[idx+len(DELIMITER):]
if !isHTTPHeader(header) {
copy(p, tmp)
return read, nil
}
for _, m := range cw.reqEndMW {
err := m.HandleRequest(cw.reqHeader)
if err != nil {
log.Printf("Error when applying request middleware: %v", err)
return 0, err
} }
}
headerReader := bufio.NewReader(bytes.NewReader(header))
reqhf, err := NewRequestHeaderFactory(headerReader)
if err != nil {
return 0, err return 0, err
} }
cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), reqhf.Method, reqhf.Path))
return read, err for _, m := range cw.reqStartMW {
err := m.HandleRequest(reqhf)
if err != nil {
log.Printf("Error when applying request middleware: %v", err)
return 0, err
}
}
cw.reqHeader = reqhf
finalHeader := reqhf.Finalize()
n := copy(p, finalHeader)
n += copy(p[n:], body)
return n, nil
} }
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter { func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter {
@ -99,9 +137,15 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
if isHTTPHeader(header) { if isHTTPHeader(header) {
resphf := NewResponseHeaderFactory(header) resphf := NewResponseHeaderFactory(header)
resphf.Set("Server", "Tunnel Please") for _, m := range cw.respMW {
err := m.HandleResponse(resphf, body)
if err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return 0, err
}
}
header = resphf.Finalize() header = resphf.Finalize()
cw.respHeader = resphf
_, err := cw.writer.Write(header) _, err := cw.writer.Write(header)
if err != nil { if err != nil {
return 0, err return 0, err
@ -117,12 +161,19 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
} }
cw.buf = nil cw.buf = nil
n, err := cw.writer.Write(p) n, err := cw.writer.Write(p)
if err != nil { if err != nil {
return n, err return n, err
} }
for _, m := range cw.respMW {
err := m.HandleResponse(cw.respHeader, p)
if err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return 0, err
}
}
return n, nil return n, nil
} }
@ -272,14 +323,28 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
} }
} }
}() }()
_, err = channel.Write(initialRequest.Finalize()) _, err = channel.Write(initialRequest.Finalize())
if err != nil { if err != nil {
log.Printf("Failed to forward request: %v", err) log.Printf("Failed to forward request: %v", err)
return return
} }
//TODO: Implement wrapper func buat add/remove middleware
fingerprintMiddleware := NewTunnelFingerprint()
loggerMiddleware := NewRequestLogger(cw.interaction, cw.RemoteAddr)
cw.respMW = append(cw.respMW, fingerprintMiddleware)
cw.reqStartMW = append(cw.reqStartMW, loggerMiddleware)
cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), initialRequest.Method, initialRequest.Path)) //TODO: Tambah req Middleware
cw.reqEndMW = nil
cw.reqHeader = initialRequest
for _, m := range cw.reqStartMW {
err := m.HandleRequest(cw.reqHeader)
if err != nil {
log.Printf("Error handling request: %v", err)
return
}
}
sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr)
return return

162
server/middleware.go Normal file
View File

@ -0,0 +1,162 @@
package server
import (
"fmt"
"net"
"time"
"tunnel_pls/session"
)
type RequestMiddleware interface {
HandleRequest(header *RequestHeaderFactory) error
}
type ResponseMiddleware interface {
HandleResponse(header *ResponseHeaderFactory, body []byte) error
}
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleRequest(header *RequestHeaderFactory) error {
return nil
}
func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
type RequestLogger struct {
interaction session.Interaction
remoteAddr net.Addr
}
func NewRequestLogger(interaction *session.Interaction, remoteAddr net.Addr) *RequestLogger {
return &RequestLogger{
interaction: *interaction,
remoteAddr: remoteAddr,
}
}
func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error {
rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path))
return nil
}
func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil }
//TODO: Implement caching atau enggak
//const maxCacheSize = 50 * 1024 * 1024
//
//type DiskCacheMiddleware struct {
// dir string
// mu sync.Mutex
// file *os.File
// path string
// cacheable bool
//}
//
//func NewDiskCacheMiddleware() *DiskCacheMiddleware {
// return &DiskCacheMiddleware{dir: "cache"}
//}
//
//func (c *DiskCacheMiddleware) ensureDir() error {
// return os.MkdirAll(c.dir, 0755)
//}
//
//func (c *DiskCacheMiddleware) cacheKey(method, path string) string {
// return fmt.Sprintf("%s_%s.cache", method, base64.URLEncoding.EncodeToString([]byte(path)))
//}
//
//func (c *DiskCacheMiddleware) filePath(method, path string) string {
// return filepath.Join(c.dir, c.cacheKey(method, path))
//}
//
//func fileExists(path string) bool {
// _, err := os.Stat(path)
// if err == nil {
// return true
// }
// if os.IsNotExist(err) {
// return false
// }
// return false
//}
//
//func canCacheRequest(header *RequestHeaderFactory) bool {
// if header.Method != "GET" {
// return false
// }
//
// if cacheControl := header.Get("Cache-Control"); cacheControl != "" {
// if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "private") || strings.Contains(cacheControl, "no-cache") || strings.Contains(cacheControl, "max-age=0") {
// return false
// }
// }
//
// if header.Get("Authorization") != "" {
// return false
// }
//
// if header.Get("Cookie") != "" {
// return false
// }
//
// return true
//}
//
//func (c *DiskCacheMiddleware) HandleRequest(header *RequestHeaderFactory) error {
// if !canCacheRequest(header) {
// c.cacheable = false
// return nil
// }
//
// c.cacheable = true
// _ = c.ensureDir()
// path := c.filePath(header.Method, header.Path)
//
// if fileExists(path + ".finish") {
// c.file = nil
// return nil
// }
//
// if c.file != nil {
// err := c.file.Close()
// if err != nil {
// return err
// }
// err = os.Rename(c.path, c.path+".finish")
// if err != nil {
// return err
// }
// }
//
// c.path = path
// f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
// if err != nil {
// return err
// }
//
// c.file = f
//
// return nil
//}
//
//func (c *DiskCacheMiddleware) HandleResponse(header *ResponseHeaderFactory, body []byte) error {
// if !c.cacheable {
// return nil
// }
//
// if c.file == nil {
// header.Set("X-Cache", "HIT")
// return nil
// }
//
// _, err := c.file.Write(body)
// if err != nil {
// return err
// }
//
// header.Set("X-Cache", "MISS")
// return nil
//}

37
session/forwarder.go Normal file
View File

@ -0,0 +1,37 @@
package session
import (
"net"
"golang.org/x/crypto/ssh"
)
type Forwarder struct {
Listener net.Listener
TunnelType TunnelType
ForwardedPort uint16
getSlug func() string
setSlug func(string)
}
type ForwardingController interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16)
HandleTCPForward(req *ssh.Request, addr string, port uint16)
AcceptTCPConnections()
}
type ForwarderInfo interface {
GetTunnelType() TunnelType
GetForwardedPort() uint16
}
func (f *Forwarder) GetTunnelType() TunnelType {
return f.TunnelType
}
func (f *Forwarder) GetForwardedPort() uint16 {
return f.ForwardedPort
}

View File

@ -12,6 +12,32 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type InteractionController interface {
SendMessage(message string)
HandleUserInput()
HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer)
HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer)
HandleSlugUpdateError()
ShowWelcomeMessage()
DisplaySlugEditor()
}
type Interaction struct {
CommandBuffer *bytes.Buffer
EditMode bool
EditSlug string
channel ssh.Channel
getSlug func() string
setSlug func(string)
session SessionCloser
forwarder ForwarderInfo
}
func (i *Interaction) SendMessage(message string) { func (i *Interaction) SendMessage(message string) {
if i.channel != nil { if i.channel != nil {
_, err := i.channel.Write([]byte(message)) _, err := i.channel.Write([]byte(message))

View File

@ -3,7 +3,6 @@ package session
import ( import (
"bytes" "bytes"
"log" "log"
"net"
"sync" "sync"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -31,26 +30,6 @@ type SessionCloser interface {
Close() error Close() error
} }
type InteractionController interface {
SendMessage(message string)
HandleUserInput()
HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer)
HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer)
HandleSlugUpdateError()
ShowWelcomeMessage()
DisplaySlugEditor()
}
type ForwardingController interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16)
HandleTCPForward(req *ssh.Request, addr string, port uint16)
AcceptTCPConnections()
}
type Session interface { type Session interface {
SessionLifecycle SessionLifecycle
InteractionController InteractionController
@ -61,41 +40,6 @@ type Lifecycle struct {
Status Status Status Status
} }
type Forwarder struct {
Listener net.Listener
TunnelType TunnelType
ForwardedPort uint16
getSlug func() string
setSlug func(string)
}
type ForwarderInfo interface {
GetTunnelType() TunnelType
GetForwardedPort() uint16
}
func (f *Forwarder) GetTunnelType() TunnelType {
return f.TunnelType
}
func (f *Forwarder) GetForwardedPort() uint16 {
return f.ForwardedPort
}
type Interaction struct {
CommandBuffer *bytes.Buffer
EditMode bool
EditSlug string
channel ssh.Channel
getSlug func() string
setSlug func(string)
session SessionCloser
forwarder ForwarderInfo
}
type SSHSession struct { type SSHSession struct {
Lifecycle *Lifecycle Lifecycle *Lifecycle
Interaction *Interaction Interaction *Interaction