Merge pull request 'fix: conn reader stuck when header have body' (#18) from staging into main
Some checks failed
Docker Build and Push / build-and-push (push) Has been cancelled
Some checks failed
Docker Build and Push / build-and-push (push) Has been cancelled
Reviewed-on: bagas/tunnl_please#18
This commit is contained in:
8
main.go
8
main.go
@ -1,14 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"tunnel_pls/server"
|
||||
"tunnel_pls/utils"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
sshConfig := &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
ServerVersion: "SSH-2.0-TunnlPls-1.0",
|
||||
|
||||
@ -10,7 +10,6 @@ import (
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"tunnel_pls/session"
|
||||
"tunnel_pls/utils"
|
||||
|
||||
@ -28,27 +27,66 @@ type CustomWriter struct {
|
||||
reader io.Reader
|
||||
headerBuf []byte
|
||||
buf []byte
|
||||
respHeader *ResponseHeaderFactory
|
||||
reqHeader *RequestHeaderFactory
|
||||
interaction *session.Interaction
|
||||
respMW []ResponseMiddleware
|
||||
reqStartMW []RequestMiddleware
|
||||
reqEndMW []RequestMiddleware
|
||||
}
|
||||
|
||||
func (cw *CustomWriter) Read(p []byte) (int, error) {
|
||||
if cw == nil {
|
||||
return 0, errors.New("can not read from nil CustomWriter")
|
||||
}
|
||||
read, err := cw.reader.Read(p)
|
||||
tmp := make([]byte, len(p))
|
||||
read, err := cw.reader.Read(tmp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
reader := bytes.NewReader(p)
|
||||
reqhf, err := NewRequestHeaderFactory(reader)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return read, io.EOF
|
||||
|
||||
tmp = tmp[:read]
|
||||
|
||||
idx := bytes.Index(tmp, DELIMITER)
|
||||
if idx == -1 {
|
||||
copy(p, tmp)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
header := tmp[:idx+len(DELIMITER)]
|
||||
body := tmp[idx+len(DELIMITER):]
|
||||
|
||||
if !isHTTPHeader(header) {
|
||||
copy(p, tmp)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
for _, m := range cw.reqEndMW {
|
||||
err := m.HandleRequest(cw.reqHeader)
|
||||
if err != nil {
|
||||
log.Printf("Error when applying request middleware: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
headerReader := bufio.NewReader(bytes.NewReader(header))
|
||||
reqhf, err := NewRequestHeaderFactory(headerReader)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), reqhf.Method, reqhf.Path))
|
||||
return read, err
|
||||
|
||||
for _, m := range cw.reqStartMW {
|
||||
err := m.HandleRequest(reqhf)
|
||||
if err != nil {
|
||||
log.Printf("Error when applying request middleware: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
cw.reqHeader = reqhf
|
||||
finalHeader := reqhf.Finalize()
|
||||
|
||||
n := copy(p, finalHeader)
|
||||
n += copy(p[n:], body)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter {
|
||||
@ -99,9 +137,15 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
|
||||
|
||||
if isHTTPHeader(header) {
|
||||
resphf := NewResponseHeaderFactory(header)
|
||||
resphf.Set("Server", "Tunnel Please")
|
||||
|
||||
for _, m := range cw.respMW {
|
||||
err := m.HandleResponse(resphf, body)
|
||||
if err != nil {
|
||||
log.Printf("Cannot apply middleware: %s\n", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
header = resphf.Finalize()
|
||||
cw.respHeader = resphf
|
||||
_, err := cw.writer.Write(header)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@ -117,12 +161,19 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
}
|
||||
|
||||
cw.buf = nil
|
||||
n, err := cw.writer.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
for _, m := range cw.respMW {
|
||||
err := m.HandleResponse(cw.respHeader, p)
|
||||
if err != nil {
|
||||
log.Printf("Cannot apply middleware: %s\n", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@ -272,14 +323,28 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = channel.Write(initialRequest.Finalize())
|
||||
if err != nil {
|
||||
log.Printf("Failed to forward request: %v", err)
|
||||
return
|
||||
}
|
||||
//TODO: Implement wrapper func buat add/remove middleware
|
||||
fingerprintMiddleware := NewTunnelFingerprint()
|
||||
loggerMiddleware := NewRequestLogger(cw.interaction, cw.RemoteAddr)
|
||||
cw.respMW = append(cw.respMW, fingerprintMiddleware)
|
||||
cw.reqStartMW = append(cw.reqStartMW, loggerMiddleware)
|
||||
|
||||
cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), initialRequest.Method, initialRequest.Path))
|
||||
//TODO: Tambah req Middleware
|
||||
cw.reqEndMW = nil
|
||||
cw.reqHeader = initialRequest
|
||||
|
||||
for _, m := range cw.reqStartMW {
|
||||
err := m.HandleRequest(cw.reqHeader)
|
||||
if err != nil {
|
||||
log.Printf("Error handling request: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr)
|
||||
return
|
||||
|
||||
162
server/middleware.go
Normal file
162
server/middleware.go
Normal file
@ -0,0 +1,162 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
"tunnel_pls/session"
|
||||
)
|
||||
|
||||
type RequestMiddleware interface {
|
||||
HandleRequest(header *RequestHeaderFactory) error
|
||||
}
|
||||
|
||||
type ResponseMiddleware interface {
|
||||
HandleResponse(header *ResponseHeaderFactory, body []byte) error
|
||||
}
|
||||
|
||||
type TunnelFingerprint struct{}
|
||||
|
||||
func NewTunnelFingerprint() *TunnelFingerprint {
|
||||
return &TunnelFingerprint{}
|
||||
}
|
||||
func (h *TunnelFingerprint) HandleRequest(header *RequestHeaderFactory) error {
|
||||
return nil
|
||||
}
|
||||
func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error {
|
||||
header.Set("Server", "Tunnel Please")
|
||||
return nil
|
||||
}
|
||||
|
||||
type RequestLogger struct {
|
||||
interaction session.Interaction
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func NewRequestLogger(interaction *session.Interaction, remoteAddr net.Addr) *RequestLogger {
|
||||
return &RequestLogger{
|
||||
interaction: *interaction,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
}
|
||||
func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error {
|
||||
rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path))
|
||||
return nil
|
||||
}
|
||||
func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil }
|
||||
|
||||
//TODO: Implement caching atau enggak
|
||||
//const maxCacheSize = 50 * 1024 * 1024
|
||||
//
|
||||
//type DiskCacheMiddleware struct {
|
||||
// dir string
|
||||
// mu sync.Mutex
|
||||
// file *os.File
|
||||
// path string
|
||||
// cacheable bool
|
||||
//}
|
||||
//
|
||||
//func NewDiskCacheMiddleware() *DiskCacheMiddleware {
|
||||
// return &DiskCacheMiddleware{dir: "cache"}
|
||||
//}
|
||||
//
|
||||
//func (c *DiskCacheMiddleware) ensureDir() error {
|
||||
// return os.MkdirAll(c.dir, 0755)
|
||||
//}
|
||||
//
|
||||
//func (c *DiskCacheMiddleware) cacheKey(method, path string) string {
|
||||
// return fmt.Sprintf("%s_%s.cache", method, base64.URLEncoding.EncodeToString([]byte(path)))
|
||||
//}
|
||||
//
|
||||
//func (c *DiskCacheMiddleware) filePath(method, path string) string {
|
||||
// return filepath.Join(c.dir, c.cacheKey(method, path))
|
||||
//}
|
||||
//
|
||||
//func fileExists(path string) bool {
|
||||
// _, err := os.Stat(path)
|
||||
// if err == nil {
|
||||
// return true
|
||||
// }
|
||||
// if os.IsNotExist(err) {
|
||||
// return false
|
||||
// }
|
||||
// return false
|
||||
//}
|
||||
//
|
||||
//func canCacheRequest(header *RequestHeaderFactory) bool {
|
||||
// if header.Method != "GET" {
|
||||
// return false
|
||||
// }
|
||||
//
|
||||
// if cacheControl := header.Get("Cache-Control"); cacheControl != "" {
|
||||
// if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "private") || strings.Contains(cacheControl, "no-cache") || strings.Contains(cacheControl, "max-age=0") {
|
||||
// return false
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if header.Get("Authorization") != "" {
|
||||
// return false
|
||||
// }
|
||||
//
|
||||
// if header.Get("Cookie") != "" {
|
||||
// return false
|
||||
// }
|
||||
//
|
||||
// return true
|
||||
//}
|
||||
//
|
||||
//func (c *DiskCacheMiddleware) HandleRequest(header *RequestHeaderFactory) error {
|
||||
// if !canCacheRequest(header) {
|
||||
// c.cacheable = false
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// c.cacheable = true
|
||||
// _ = c.ensureDir()
|
||||
// path := c.filePath(header.Method, header.Path)
|
||||
//
|
||||
// if fileExists(path + ".finish") {
|
||||
// c.file = nil
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// if c.file != nil {
|
||||
// err := c.file.Close()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// err = os.Rename(c.path, c.path+".finish")
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// c.path = path
|
||||
// f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// c.file = f
|
||||
//
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
//func (c *DiskCacheMiddleware) HandleResponse(header *ResponseHeaderFactory, body []byte) error {
|
||||
// if !c.cacheable {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// if c.file == nil {
|
||||
// header.Set("X-Cache", "HIT")
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// _, err := c.file.Write(body)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//
|
||||
// header.Set("X-Cache", "MISS")
|
||||
// return nil
|
||||
//}
|
||||
37
session/forwarder.go
Normal file
37
session/forwarder.go
Normal file
@ -0,0 +1,37 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Forwarder struct {
|
||||
Listener net.Listener
|
||||
TunnelType TunnelType
|
||||
ForwardedPort uint16
|
||||
|
||||
getSlug func() string
|
||||
setSlug func(string)
|
||||
}
|
||||
|
||||
type ForwardingController interface {
|
||||
HandleGlobalRequest(ch <-chan *ssh.Request)
|
||||
HandleTCPIPForward(req *ssh.Request)
|
||||
HandleHTTPForward(req *ssh.Request, port uint16)
|
||||
HandleTCPForward(req *ssh.Request, addr string, port uint16)
|
||||
AcceptTCPConnections()
|
||||
}
|
||||
|
||||
type ForwarderInfo interface {
|
||||
GetTunnelType() TunnelType
|
||||
GetForwardedPort() uint16
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetTunnelType() TunnelType {
|
||||
return f.TunnelType
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetForwardedPort() uint16 {
|
||||
return f.ForwardedPort
|
||||
}
|
||||
@ -12,6 +12,32 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type InteractionController interface {
|
||||
SendMessage(message string)
|
||||
HandleUserInput()
|
||||
HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
|
||||
HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer)
|
||||
HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
|
||||
HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer)
|
||||
HandleSlugUpdateError()
|
||||
ShowWelcomeMessage()
|
||||
DisplaySlugEditor()
|
||||
}
|
||||
|
||||
type Interaction struct {
|
||||
CommandBuffer *bytes.Buffer
|
||||
EditMode bool
|
||||
EditSlug string
|
||||
channel ssh.Channel
|
||||
|
||||
getSlug func() string
|
||||
setSlug func(string)
|
||||
|
||||
session SessionCloser
|
||||
|
||||
forwarder ForwarderInfo
|
||||
}
|
||||
|
||||
func (i *Interaction) SendMessage(message string) {
|
||||
if i.channel != nil {
|
||||
_, err := i.channel.Write([]byte(message))
|
||||
|
||||
@ -3,7 +3,6 @@ package session
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
@ -31,26 +30,6 @@ type SessionCloser interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
type InteractionController interface {
|
||||
SendMessage(message string)
|
||||
HandleUserInput()
|
||||
HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
|
||||
HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer)
|
||||
HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer)
|
||||
HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer)
|
||||
HandleSlugUpdateError()
|
||||
ShowWelcomeMessage()
|
||||
DisplaySlugEditor()
|
||||
}
|
||||
|
||||
type ForwardingController interface {
|
||||
HandleGlobalRequest(ch <-chan *ssh.Request)
|
||||
HandleTCPIPForward(req *ssh.Request)
|
||||
HandleHTTPForward(req *ssh.Request, port uint16)
|
||||
HandleTCPForward(req *ssh.Request, addr string, port uint16)
|
||||
AcceptTCPConnections()
|
||||
}
|
||||
|
||||
type Session interface {
|
||||
SessionLifecycle
|
||||
InteractionController
|
||||
@ -61,41 +40,6 @@ type Lifecycle struct {
|
||||
Status Status
|
||||
}
|
||||
|
||||
type Forwarder struct {
|
||||
Listener net.Listener
|
||||
TunnelType TunnelType
|
||||
ForwardedPort uint16
|
||||
|
||||
getSlug func() string
|
||||
setSlug func(string)
|
||||
}
|
||||
|
||||
type ForwarderInfo interface {
|
||||
GetTunnelType() TunnelType
|
||||
GetForwardedPort() uint16
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetTunnelType() TunnelType {
|
||||
return f.TunnelType
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetForwardedPort() uint16 {
|
||||
return f.ForwardedPort
|
||||
}
|
||||
|
||||
type Interaction struct {
|
||||
CommandBuffer *bytes.Buffer
|
||||
EditMode bool
|
||||
EditSlug string
|
||||
channel ssh.Channel
|
||||
|
||||
getSlug func() string
|
||||
setSlug func(string)
|
||||
|
||||
session SessionCloser
|
||||
|
||||
forwarder ForwarderInfo
|
||||
}
|
||||
type SSHSession struct {
|
||||
Lifecycle *Lifecycle
|
||||
Interaction *Interaction
|
||||
|
||||
Reference in New Issue
Block a user