diff --git a/main.go b/main.go index 1fb275c..9953588 100644 --- a/main.go +++ b/main.go @@ -1,14 +1,20 @@ package main import ( - "golang.org/x/crypto/ssh" "log" + "net/http" + _ "net/http/pprof" "os" "tunnel_pls/server" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() sshConfig := &ssh.ServerConfig{ NoClientAuth: true, ServerVersion: "SSH-2.0-TunnlPls-1.0", diff --git a/server/http.go b/server/http.go index 22cf81e..3f1aaba 100644 --- a/server/http.go +++ b/server/http.go @@ -10,7 +10,6 @@ import ( "net" "regexp" "strings" - "time" "tunnel_pls/session" "tunnel_pls/utils" @@ -28,27 +27,66 @@ type CustomWriter struct { reader io.Reader headerBuf []byte buf []byte + respHeader *ResponseHeaderFactory + reqHeader *RequestHeaderFactory interaction *session.Interaction + respMW []ResponseMiddleware + reqStartMW []RequestMiddleware + reqEndMW []RequestMiddleware } func (cw *CustomWriter) Read(p []byte) (int, error) { - if cw == nil { - return 0, errors.New("can not read from nil CustomWriter") - } - read, err := cw.reader.Read(p) + tmp := make([]byte, len(p)) + read, err := cw.reader.Read(tmp) if err != nil { return 0, err } - reader := bytes.NewReader(p) - reqhf, err := NewRequestHeaderFactory(reader) - if err != nil { - if errors.Is(err, io.EOF) { - return read, io.EOF + + tmp = tmp[:read] + + idx := bytes.Index(tmp, DELIMITER) + if idx == -1 { + copy(p, tmp) + return read, nil + } + + header := tmp[:idx+len(DELIMITER)] + body := tmp[idx+len(DELIMITER):] + + if !isHTTPHeader(header) { + copy(p, tmp) + return read, nil + } + + for _, m := range cw.reqEndMW { + err := m.HandleRequest(cw.reqHeader) + if err != nil { + log.Printf("Error when applying request middleware: %v", err) + return 0, err } + } + + headerReader := bufio.NewReader(bytes.NewReader(header)) + reqhf, err := NewRequestHeaderFactory(headerReader) + if err != nil { return 0, err } - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), reqhf.Method, reqhf.Path)) - return read, err + + for _, m := range cw.reqStartMW { + err := m.HandleRequest(reqhf) + if err != nil { + log.Printf("Error when applying request middleware: %v", err) + return 0, err + } + } + + cw.reqHeader = reqhf + finalHeader := reqhf.Finalize() + + n := copy(p, finalHeader) + n += copy(p[n:], body) + + return n, nil } func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter { @@ -99,9 +137,15 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { if isHTTPHeader(header) { resphf := NewResponseHeaderFactory(header) - resphf.Set("Server", "Tunnel Please") - + for _, m := range cw.respMW { + err := m.HandleResponse(resphf, body) + if err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return 0, err + } + } header = resphf.Finalize() + cw.respHeader = resphf _, err := cw.writer.Write(header) if err != nil { return 0, err @@ -117,12 +161,19 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return len(p), nil } } + cw.buf = nil n, err := cw.writer.Write(p) if err != nil { return n, err } - + for _, m := range cw.respMW { + err := m.HandleResponse(cw.respHeader, p) + if err != nil { + log.Printf("Cannot apply middleware: %s\n", err) + return 0, err + } + } return n, nil } @@ -272,14 +323,28 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS } } }() - _, err = channel.Write(initialRequest.Finalize()) if err != nil { log.Printf("Failed to forward request: %v", err) return } + //TODO: Implement wrapper func buat add/remove middleware + fingerprintMiddleware := NewTunnelFingerprint() + loggerMiddleware := NewRequestLogger(cw.interaction, cw.RemoteAddr) + cw.respMW = append(cw.respMW, fingerprintMiddleware) + cw.reqStartMW = append(cw.reqStartMW, loggerMiddleware) - cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), cw.RemoteAddr.String(), initialRequest.Method, initialRequest.Path)) + //TODO: Tambah req Middleware + cw.reqEndMW = nil + cw.reqHeader = initialRequest + + for _, m := range cw.reqStartMW { + err := m.HandleRequest(cw.reqHeader) + if err != nil { + log.Printf("Error handling request: %v", err) + return + } + } sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) return diff --git a/server/middleware.go b/server/middleware.go new file mode 100644 index 0000000..08ee035 --- /dev/null +++ b/server/middleware.go @@ -0,0 +1,162 @@ +package server + +import ( + "fmt" + "net" + "time" + "tunnel_pls/session" +) + +type RequestMiddleware interface { + HandleRequest(header *RequestHeaderFactory) error +} + +type ResponseMiddleware interface { + HandleResponse(header *ResponseHeaderFactory, body []byte) error +} + +type TunnelFingerprint struct{} + +func NewTunnelFingerprint() *TunnelFingerprint { + return &TunnelFingerprint{} +} +func (h *TunnelFingerprint) HandleRequest(header *RequestHeaderFactory) error { + return nil +} +func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error { + header.Set("Server", "Tunnel Please") + return nil +} + +type RequestLogger struct { + interaction session.Interaction + remoteAddr net.Addr +} + +func NewRequestLogger(interaction *session.Interaction, remoteAddr net.Addr) *RequestLogger { + return &RequestLogger{ + interaction: *interaction, + remoteAddr: remoteAddr, + } +} +func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error { + rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path)) + return nil +} +func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil } + +//TODO: Implement caching atau enggak +//const maxCacheSize = 50 * 1024 * 1024 +// +//type DiskCacheMiddleware struct { +// dir string +// mu sync.Mutex +// file *os.File +// path string +// cacheable bool +//} +// +//func NewDiskCacheMiddleware() *DiskCacheMiddleware { +// return &DiskCacheMiddleware{dir: "cache"} +//} +// +//func (c *DiskCacheMiddleware) ensureDir() error { +// return os.MkdirAll(c.dir, 0755) +//} +// +//func (c *DiskCacheMiddleware) cacheKey(method, path string) string { +// return fmt.Sprintf("%s_%s.cache", method, base64.URLEncoding.EncodeToString([]byte(path))) +//} +// +//func (c *DiskCacheMiddleware) filePath(method, path string) string { +// return filepath.Join(c.dir, c.cacheKey(method, path)) +//} +// +//func fileExists(path string) bool { +// _, err := os.Stat(path) +// if err == nil { +// return true +// } +// if os.IsNotExist(err) { +// return false +// } +// return false +//} +// +//func canCacheRequest(header *RequestHeaderFactory) bool { +// if header.Method != "GET" { +// return false +// } +// +// if cacheControl := header.Get("Cache-Control"); cacheControl != "" { +// if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "private") || strings.Contains(cacheControl, "no-cache") || strings.Contains(cacheControl, "max-age=0") { +// return false +// } +// } +// +// if header.Get("Authorization") != "" { +// return false +// } +// +// if header.Get("Cookie") != "" { +// return false +// } +// +// return true +//} +// +//func (c *DiskCacheMiddleware) HandleRequest(header *RequestHeaderFactory) error { +// if !canCacheRequest(header) { +// c.cacheable = false +// return nil +// } +// +// c.cacheable = true +// _ = c.ensureDir() +// path := c.filePath(header.Method, header.Path) +// +// if fileExists(path + ".finish") { +// c.file = nil +// return nil +// } +// +// if c.file != nil { +// err := c.file.Close() +// if err != nil { +// return err +// } +// err = os.Rename(c.path, c.path+".finish") +// if err != nil { +// return err +// } +// } +// +// c.path = path +// f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) +// if err != nil { +// return err +// } +// +// c.file = f +// +// return nil +//} +// +//func (c *DiskCacheMiddleware) HandleResponse(header *ResponseHeaderFactory, body []byte) error { +// if !c.cacheable { +// return nil +// } +// +// if c.file == nil { +// header.Set("X-Cache", "HIT") +// return nil +// } +// +// _, err := c.file.Write(body) +// if err != nil { +// return err +// } +// +// header.Set("X-Cache", "MISS") +// return nil +//} diff --git a/session/forwarder.go b/session/forwarder.go new file mode 100644 index 0000000..e7abc17 --- /dev/null +++ b/session/forwarder.go @@ -0,0 +1,37 @@ +package session + +import ( + "net" + + "golang.org/x/crypto/ssh" +) + +type Forwarder struct { + Listener net.Listener + TunnelType TunnelType + ForwardedPort uint16 + + getSlug func() string + setSlug func(string) +} + +type ForwardingController interface { + HandleGlobalRequest(ch <-chan *ssh.Request) + HandleTCPIPForward(req *ssh.Request) + HandleHTTPForward(req *ssh.Request, port uint16) + HandleTCPForward(req *ssh.Request, addr string, port uint16) + AcceptTCPConnections() +} + +type ForwarderInfo interface { + GetTunnelType() TunnelType + GetForwardedPort() uint16 +} + +func (f *Forwarder) GetTunnelType() TunnelType { + return f.TunnelType +} + +func (f *Forwarder) GetForwardedPort() uint16 { + return f.ForwardedPort +} diff --git a/session/interaction.go b/session/interaction.go index b22c87b..cfa1ce1 100644 --- a/session/interaction.go +++ b/session/interaction.go @@ -12,6 +12,32 @@ import ( "golang.org/x/crypto/ssh" ) +type InteractionController interface { + SendMessage(message string) + HandleUserInput() + HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) + HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer) + HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) + HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer) + HandleSlugUpdateError() + ShowWelcomeMessage() + DisplaySlugEditor() +} + +type Interaction struct { + CommandBuffer *bytes.Buffer + EditMode bool + EditSlug string + channel ssh.Channel + + getSlug func() string + setSlug func(string) + + session SessionCloser + + forwarder ForwarderInfo +} + func (i *Interaction) SendMessage(message string) { if i.channel != nil { _, err := i.channel.Write([]byte(message)) diff --git a/session/session.go b/session/session.go index c44d656..2a38c6a 100644 --- a/session/session.go +++ b/session/session.go @@ -3,7 +3,6 @@ package session import ( "bytes" "log" - "net" "sync" "golang.org/x/crypto/ssh" @@ -31,26 +30,6 @@ type SessionCloser interface { Close() error } -type InteractionController interface { - SendMessage(message string) - HandleUserInput() - HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer) - HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer) - HandleSlugUpdateError() - ShowWelcomeMessage() - DisplaySlugEditor() -} - -type ForwardingController interface { - HandleGlobalRequest(ch <-chan *ssh.Request) - HandleTCPIPForward(req *ssh.Request) - HandleHTTPForward(req *ssh.Request, port uint16) - HandleTCPForward(req *ssh.Request, addr string, port uint16) - AcceptTCPConnections() -} - type Session interface { SessionLifecycle InteractionController @@ -61,41 +40,6 @@ type Lifecycle struct { Status Status } -type Forwarder struct { - Listener net.Listener - TunnelType TunnelType - ForwardedPort uint16 - - getSlug func() string - setSlug func(string) -} - -type ForwarderInfo interface { - GetTunnelType() TunnelType - GetForwardedPort() uint16 -} - -func (f *Forwarder) GetTunnelType() TunnelType { - return f.TunnelType -} - -func (f *Forwarder) GetForwardedPort() uint16 { - return f.ForwardedPort -} - -type Interaction struct { - CommandBuffer *bytes.Buffer - EditMode bool - EditSlug string - channel ssh.Channel - - getSlug func() string - setSlug func(string) - - session SessionCloser - - forwarder ForwarderInfo -} type SSHSession struct { Lifecycle *Lifecycle Interaction *Interaction