fix: correct read/write handling in CustomWriter
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 5m7s

This commit is contained in:
2025-12-06 22:17:55 +07:00
parent 0b8bc1dbba
commit 69c3e78728
7 changed files with 137 additions and 84 deletions

19
main.go
View File

@ -10,28 +10,25 @@ import (
) )
func main() { func main() {
sshConfig := &ssh.ServerConfig{
NoClientAuth: true,
ServerVersion: "SSH-2.0-TunnlPls-1.0",
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
return nil, nil
},
}
log.SetOutput(os.Stdout) log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile) log.SetFlags(log.LstdFlags | log.Lshortfile)
sshConfig := &ssh.ServerConfig{
NoClientAuth: true,
ServerVersion: "SSH-2.0-TunnlPls-1.0",
}
privateBytes, err := os.ReadFile(utils.Getenv("ssh_private_key")) privateBytes, err := os.ReadFile(utils.Getenv("ssh_private_key"))
if err != nil { if err != nil {
log.Fatalf("Failed to load private key : %s", err.Error()) log.Fatalf("Failed to load private key: %s", err)
} }
private, err := ssh.ParsePrivateKey(privateBytes) private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil { if err != nil {
log.Fatal("Failed to parse private key") log.Fatalf("Failed to parse private key: %s", err)
} }
sshConfig.AddHostKey(private) sshConfig.AddHostKey(private)
app := server.NewServer(*sshConfig) app := server.NewServer(sshConfig)
app.Start() app.Start()
} }

View File

@ -149,8 +149,6 @@ func (req *RequestHeaderFactory) Finalize() []byte {
buf.Write(req.startLine) buf.Write(req.startLine)
buf.WriteString("\r\n") buf.WriteString("\r\n")
req.headers["X-HF"] = "modified"
for key, val := range req.headers { for key, val := range req.headers {
buf.WriteString(key) buf.WriteString(key)
buf.WriteString(": ") buf.WriteString(": ")

View File

@ -30,6 +30,7 @@ type CustomWriter struct {
respMW []ResponseMiddleware respMW []ResponseMiddleware
reqStartMW []RequestMiddleware reqStartMW []RequestMiddleware
reqEndMW []RequestMiddleware reqEndMW []RequestMiddleware
overflow []byte
} }
func (cw *CustomWriter) SetInteraction(interaction Interaction) { func (cw *CustomWriter) SetInteraction(interaction Interaction) {
@ -37,9 +38,17 @@ func (cw *CustomWriter) SetInteraction(interaction Interaction) {
} }
func (cw *CustomWriter) Read(p []byte) (int, error) { func (cw *CustomWriter) Read(p []byte) (int, error) {
if len(cw.overflow) > 0 {
n := copy(p, cw.overflow)
cw.overflow = cw.overflow[n:]
if len(cw.overflow) == 0 {
cw.overflow = nil
}
return n, nil
}
tmp := make([]byte, len(p)) tmp := make([]byte, len(p))
read, err := cw.reader.Read(tmp) read, err := cw.reader.Read(tmp)
if err != nil { if read == 0 && err != nil {
return 0, err return 0, err
} }
@ -48,6 +57,9 @@ func (cw *CustomWriter) Read(p []byte) (int, error) {
idx := bytes.Index(tmp, DELIMITER) idx := bytes.Index(tmp, DELIMITER)
if idx == -1 { if idx == -1 {
copy(p, tmp) copy(p, tmp)
if err != nil {
return read, err
}
return read, nil return read, nil
} }
@ -74,18 +86,24 @@ func (cw *CustomWriter) Read(p []byte) (int, error) {
} }
for _, m := range cw.reqStartMW { for _, m := range cw.reqStartMW {
err := m.HandleRequest(reqhf) if mwErr := m.HandleRequest(reqhf); mwErr != nil {
if err != nil { log.Printf("Error when applying request middleware: %v", mwErr)
log.Printf("Error when applying request middleware: %v", err) return 0, mwErr
return 0, err
} }
} }
cw.reqHeader = reqhf cw.reqHeader = reqhf
finalHeader := reqhf.Finalize() finalHeader := reqhf.Finalize()
n := copy(p, finalHeader) combined := append(finalHeader, body...)
n += copy(p[n:], body)
n := copy(p, combined)
if n > len(p) {
cw.overflow = make([]byte, len(combined)-n)
copy(cw.overflow, combined[n:])
log.Printf("output buffer too small (%d vs %d)", len(p), n)
}
return n, nil return n, nil
} }
@ -106,9 +124,7 @@ var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
func isHTTPHeader(buf []byte) bool { func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n")) lines := bytes.Split(buf, []byte("\r\n"))
if len(lines) < 1 {
return false
}
startLine := string(lines[0]) startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) { if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false return false
@ -118,7 +134,8 @@ func isHTTPHeader(buf []byte) bool {
if len(line) == 0 { if len(line) == 0 {
break break
} }
if !bytes.Contains(line, []byte(":")) { colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false return false
} }
} }
@ -130,13 +147,30 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
return cw.writer.Write(p) return cw.writer.Write(p)
} }
if cw.respHeader != nil {
n, err := cw.writer.Write(p)
if err != nil {
return n, err
}
return n, nil
}
cw.buf = append(cw.buf, p...) cw.buf = append(cw.buf, p...)
// TODO: implement middleware buat cache system dll
if idx := bytes.Index(cw.buf, DELIMITER); idx != -1 { idx := bytes.Index(cw.buf, DELIMITER)
if idx == -1 {
return len(p), nil
}
header := cw.buf[:idx+len(DELIMITER)] header := cw.buf[:idx+len(DELIMITER)]
body := cw.buf[idx+len(DELIMITER):] body := cw.buf[idx+len(DELIMITER):]
if isHTTPHeader(header) { if !isHTTPHeader(header) {
n, err := cw.writer.Write(cw.buf)
cw.buf = nil
return n, err
}
resphf := NewResponseHeaderFactory(header) resphf := NewResponseHeaderFactory(header)
for _, m := range cw.respMW { for _, m := range cw.respMW {
err := m.HandleResponse(resphf, body) err := m.HandleResponse(resphf, body)
@ -147,13 +181,13 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
} }
header = resphf.Finalize() header = resphf.Finalize()
cw.respHeader = resphf cw.respHeader = resphf
_, err := cw.writer.Write(header) _, err := cw.writer.Write(header)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if len(body) > 0 { if len(body) > 0 {
_, err := cw.writer.Write(body) _, err = cw.writer.Write(body)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -161,22 +195,6 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
cw.buf = nil cw.buf = nil
return len(p), nil return len(p), nil
} }
}
cw.buf = nil
n, err := cw.writer.Write(p)
if err != nil {
return n, err
}
for _, m := range cw.respMW {
err := m.HandleResponse(cw.respHeader, p)
if err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return 0, err
}
}
return n, nil
}
func (cw *CustomWriter) AddInteraction(interaction Interaction) { func (cw *CustomWriter) AddInteraction(interaction Interaction) {
cw.interaction = interaction cw.interaction = interaction
@ -318,9 +336,11 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
//TODO: Implement wrapper func buat add/remove middleware //TODO: Implement wrapper func buat add/remove middleware
fingerprintMiddleware := NewTunnelFingerprint() fingerprintMiddleware := NewTunnelFingerprint()
loggerMiddleware := NewRequestLogger(cw.interaction, cw.RemoteAddr) loggerMiddleware := NewRequestLogger(cw.interaction, cw.RemoteAddr)
forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr)
cw.respMW = append(cw.respMW, fingerprintMiddleware) cw.respMW = append(cw.respMW, fingerprintMiddleware)
cw.reqStartMW = append(cw.reqStartMW, loggerMiddleware) cw.reqStartMW = append(cw.reqStartMW, loggerMiddleware)
cw.reqStartMW = append(cw.reqStartMW, forwardedForMiddleware)
//TODO: Tambah req Middleware //TODO: Tambah req Middleware
cw.reqEndMW = nil cw.reqEndMW = nil
cw.reqHeader = initialRequest cw.reqHeader = initialRequest

View File

@ -19,9 +19,7 @@ type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint { func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{} return &TunnelFingerprint{}
} }
func (h *TunnelFingerprint) HandleRequest(header *RequestHeaderFactory) error {
return nil
}
func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error { func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error {
header.Set("Server", "Tunnel Please") header.Set("Server", "Tunnel Please")
return nil return nil
@ -44,7 +42,22 @@ func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error {
return nil return nil
} }
func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil } type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header *RequestHeaderFactory) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
//TODO: Implement caching atau enggak //TODO: Implement caching atau enggak
//const maxCacheSize = 50 * 1024 * 1024 //const maxCacheSize = 50 * 1024 * 1024

View File

@ -16,7 +16,7 @@ type Server struct {
HttpServer *http.Server HttpServer *http.Server
} }
func NewServer(config ssh.ServerConfig) *Server { func NewServer(config *ssh.ServerConfig) *Server {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("port"))) listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("port")))
if err != nil { if err != nil {
log.Fatalf("failed to listen on port 2200: %v", err) log.Fatalf("failed to listen on port 2200: %v", err)
@ -39,7 +39,7 @@ func NewServer(config ssh.ServerConfig) *Server {
}() }()
return &Server{ return &Server{
Conn: &listener, Conn: &listener,
Config: &config, Config: config,
} }
} }

View File

@ -160,21 +160,29 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
}
return return
} }
if !registerClient(slug, s) {
log.Printf("Failed to register client with slug: %s", slug)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return return
} }
s.SlugManager.Set(slug) s.SlugManager.Set(slug)
registerClient(slug, s)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err := binary.Write(buf, binary.BigEndian, uint32(80)) err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to write port to buffer:", err)
unregisterClient(slug)
return return
} }
log.Printf("HTTP forwarding approved on port: %d", 80) log.Printf("HTTP forwarding approved on port: %d", portToBind)
domain := utils.Getenv("domain") domain := utils.Getenv("domain")
protocol := "http" protocol := "http"
@ -184,9 +192,11 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
s.Interaction.ShowWelcomeMessage() s.Interaction.ShowWelcomeMessage()
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain))
err = req.Reply(true, buf.Bytes()) err = req.Reply(true, buf.Bytes())
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
unregisterClient(slug)
return return
} }
} }
@ -194,7 +204,6 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
s.Forwarder.SetType(types.TCP) s.Forwarder.SetType(types.TCP)
log.Printf("Requested forwarding on %s:%d", addr, portToBind) log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil { if err != nil {
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
@ -209,25 +218,36 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
} }
return return
} }
s.Forwarder.SetListener(listener)
s.Forwarder.SetForwardedPort(portToBind)
s.Interaction.ShowWelcomeMessage()
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort()))
go s.Forwarder.AcceptTCPConnections()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to write port to buffer:", err)
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return return
} }
return
}
log.Printf("TCP forwarding approved on port: %d", portToBind) log.Printf("TCP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes()) err = req.Reply(true, buf.Bytes())
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return return
} }
return
}
s.Forwarder.SetListener(listener)
s.Forwarder.SetForwardedPort(portToBind)
s.Interaction.ShowWelcomeMessage()
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort()))
go s.Forwarder.AcceptTCPConnections()
} }
func generateUniqueSlug() string { func generateUniqueSlug() string {

View File

@ -30,6 +30,8 @@ type SSHSession struct {
Interaction interaction.Controller Interaction interaction.Controller
Forwarder forwarder.ForwardingController Forwarder forwarder.ForwardingController
SlugManager slug.Manager SlugManager slug.Manager
channelOnce sync.Once
} }
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
@ -73,20 +75,23 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
go session.Lifecycle.WaitForRunningStatus() go session.Lifecycle.WaitForRunningStatus()
for channel := range sshChan { for channel := range sshChan {
ch, reqs, _ := channel.Accept() ch, reqs, err := channel.Accept()
if session.Lifecycle.GetChannel() == nil { if err != nil {
log.Printf("failed to accept channel: %v", err)
continue
}
session.channelOnce.Do(func() {
session.Lifecycle.SetChannel(ch) session.Lifecycle.SetChannel(ch)
session.Interaction.SetChannel(ch) session.Interaction.SetChannel(ch)
session.Lifecycle.SetStatus(types.SETUP) session.Lifecycle.SetStatus(types.SETUP)
go session.HandleGlobalRequest(forwardingReq) go session.HandleGlobalRequest(forwardingReq)
} })
go session.HandleGlobalRequest(reqs) go session.HandleGlobalRequest(reqs)
} }
err := session.Lifecycle.Close() if err := session.Lifecycle.Close(); err != nil {
if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return
}() }()
} }