diff --git a/go.mod b/go.mod index 31fdc54..09be3c3 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,8 @@ module tunnel_pls go 1.24.4 require ( - github.com/a-h/templ v0.3.833 github.com/joho/godotenv v1.5.1 - golang.org/x/crypto v0.32.0 - golang.org/x/net v0.33.0 + golang.org/x/crypto v0.45.0 ) -require ( - github.com/gorilla/websocket v1.5.3 // indirect - golang.org/x/sys v0.29.0 // indirect -) +require golang.org/x/sys v0.38.0 // indirect diff --git a/go.sum b/go.sum index e14b727..27269bf 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,13 @@ -github.com/a-h/templ v0.3.833 h1:L/KOk/0VvVTBegtE0fp2RJQiBm7/52Zxv5fqlEHiQUU= -github.com/a-h/templ v0.3.833/go.mod h1:cAu4AiZhtJfBjMY0HASlyzvkrtjnHWPeEsyGK2YYmfk= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= -github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= diff --git a/server/http.go b/server/http.go index 3f1aaba..0960932 100644 --- a/server/http.go +++ b/server/http.go @@ -11,16 +11,13 @@ import ( "regexp" "strings" "tunnel_pls/session" + "tunnel_pls/types" "tunnel_pls/utils" - - "golang.org/x/crypto/ssh" ) -var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + - "Content-Length: 11\r\n" + - "Content-Type: text/plain\r\n\r\n" + - "Bad Gateway") - +type Interaction interface { + SendMessage(message string) +} type CustomWriter struct { RemoteAddr net.Addr writer io.Writer @@ -29,12 +26,16 @@ type CustomWriter struct { buf []byte respHeader *ResponseHeaderFactory reqHeader *RequestHeaderFactory - interaction *session.Interaction + interaction Interaction respMW []ResponseMiddleware reqStartMW []RequestMiddleware reqEndMW []RequestMiddleware } +func (cw *CustomWriter) SetInteraction(interaction Interaction) { + cw.interaction = interaction +} + func (cw *CustomWriter) Read(p []byte) (int, error) { tmp := make([]byte, len(p)) read, err := cw.reader.Read(tmp) @@ -125,7 +126,7 @@ func isHTTPHeader(buf []byte) bool { } func (cw *CustomWriter) Write(p []byte) (int, error) { - if len(p) == len(BadGatewayResponse) && bytes.Equal(p, BadGatewayResponse) { + if len(p) == len(types.BadGatewayResponse) && bytes.Equal(p, types.BadGatewayResponse) { return cw.writer.Write(p) } @@ -177,7 +178,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return n, nil } -func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) { +func (cw *CustomWriter) AddInteraction(interaction Interaction) { cw.interaction = interaction } @@ -211,7 +212,7 @@ func NewHTTPServer() error { func Handler(conn net.Conn) { defer func() { err := conn.Close() - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { log.Printf("Error closing connection: %v", err) return } @@ -287,32 +288,18 @@ func Handler(conn net.Conn) { return } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - + cw.SetInteraction(sshSession.Interaction) forwardRequest(cw, reqhf, sshSession) return } func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { - cw.AddInteraction(sshSession.Interaction) - originHost, originPort := ParseAddr(cw.RemoteAddr.String()) - payload := createForwardedTCPIPPayload(originHost, uint16(originPort), sshSession.Forwarder.GetForwardedPort()) - channel, reqs, err := sshSession.Conn.OpenChannel("forwarded-tcpip", payload) + payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr) + channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) - sendBadGatewayResponse(cw) return } - defer func(channel ssh.Channel) { - err := channel.Close() - if err != nil { - if errors.Is(err, io.EOF) { - sendBadGatewayResponse(cw) - return - } - log.Println("Failed to close connection:", err) - return - } - }(channel) go func() { for req := range reqs { @@ -346,14 +333,6 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS } } - sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) + sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) return } - -func sendBadGatewayResponse(writer io.Writer) { - _, err := writer.Write(BadGatewayResponse) - if err != nil { - log.Printf("failed to write Bad Gateway response: %v", err) - return - } -} diff --git a/server/https.go b/server/https.go index f4ecf99..cbe7c86 100644 --- a/server/https.go +++ b/server/https.go @@ -112,7 +112,7 @@ func HandlerTLS(conn net.Conn) { return } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - + cw.SetInteraction(sshSession.Interaction) forwardRequest(cw, reqhf, sshSession) return } diff --git a/server/middleware.go b/server/middleware.go index 08ee035..a28bdab 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -4,7 +4,6 @@ import ( "fmt" "net" "time" - "tunnel_pls/session" ) type RequestMiddleware interface { @@ -29,20 +28,22 @@ func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body [ } type RequestLogger struct { - interaction session.Interaction + interaction Interaction remoteAddr net.Addr } -func NewRequestLogger(interaction *session.Interaction, remoteAddr net.Addr) *RequestLogger { +func NewRequestLogger(interaction Interaction, remoteAddr net.Addr) *RequestLogger { return &RequestLogger{ - interaction: *interaction, + interaction: interaction, remoteAddr: remoteAddr, } } + func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error { rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path)) return nil } + func (rl *RequestLogger) HandleResponse(header *ResponseHeaderFactory, body []byte) error { return nil } //TODO: Implement caching atau enggak diff --git a/server/server.go b/server/server.go index 0e6bdb6..9d01817 100644 --- a/server/server.go +++ b/server/server.go @@ -1,13 +1,10 @@ package server import ( - "bytes" - "encoding/binary" "fmt" "log" "net" "net/http" - "strconv" "tunnel_pls/utils" "golang.org/x/crypto/ssh" @@ -58,41 +55,3 @@ func (s *Server) Start() { go s.handleConnection(conn) } } - -func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { - var buf bytes.Buffer - - writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(port)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - writeSSHString(&buf, host) - err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - - return buf.Bytes() -} - -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return - } - buffer.WriteString(str) -} - -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) -} diff --git a/session/forwarder.go b/session/forwarder.go deleted file mode 100644 index e7abc17..0000000 --- a/session/forwarder.go +++ /dev/null @@ -1,37 +0,0 @@ -package session - -import ( - "net" - - "golang.org/x/crypto/ssh" -) - -type Forwarder struct { - Listener net.Listener - TunnelType TunnelType - ForwardedPort uint16 - - getSlug func() string - setSlug func(string) -} - -type ForwardingController interface { - HandleGlobalRequest(ch <-chan *ssh.Request) - HandleTCPIPForward(req *ssh.Request) - HandleHTTPForward(req *ssh.Request, port uint16) - HandleTCPForward(req *ssh.Request, addr string, port uint16) - AcceptTCPConnections() -} - -type ForwarderInfo interface { - GetTunnelType() TunnelType - GetForwardedPort() uint16 -} - -func (f *Forwarder) GetTunnelType() TunnelType { - return f.TunnelType -} - -func (f *Forwarder) GetForwardedPort() uint16 { - return f.ForwardedPort -} diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go new file mode 100644 index 0000000..3d846e6 --- /dev/null +++ b/session/forwarder/forwarder.go @@ -0,0 +1,185 @@ +package forwarder + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "log" + "net" + "strconv" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" +) + +type Forwarder struct { + Listener net.Listener + TunnelType types.TunnelType + ForwardedPort uint16 + SlugManager slug.Manager + Lifecycle Lifecycle +} + +type Lifecycle interface { + GetConnection() ssh.Conn +} + +type ForwardingController interface { + AcceptTCPConnections() + SetType(tunnelType types.TunnelType) + GetTunnelType() types.TunnelType + GetForwardedPort() uint16 + SetForwardedPort(port uint16) + SetListener(listener net.Listener) + GetListener() net.Listener + Close() error + HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) + SetLifecycle(lifecycle Lifecycle) + CreateForwardedTCPIPPayload(origin net.Addr) []byte + WriteBadGatewayResponse(dst io.Writer) +} + +func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { + f.Lifecycle = lifecycle +} + +func (f *Forwarder) AcceptTCPConnections() { + for { + conn, err := f.GetListener().Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + log.Printf("Error accepting connection: %v", err) + continue + } + payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) + channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + return + } + + go func() { + for req := range reqs { + err := req.Reply(false, nil) + if err != nil { + log.Printf("Failed to reply to request: %v", err) + return + } + } + }() + go f.HandleConnection(conn, channel, conn.RemoteAddr()) + } +} + +func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { + defer func(src ssh.Channel) { + _, err := io.Copy(io.Discard, src) + if err != nil { + log.Printf("Failed to discard connection: %v", err) + } + + err = src.Close() + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing connection: %v", err) + } + }(src) + log.Printf("Handling new forwarded connection from %s", remoteAddr) + + go func() { + _, err := io.Copy(src, dst) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + log.Printf("Error copying from conn.Reader to channel: %v", err) + } + }() + + _, err := io.Copy(dst, src) + + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error copying from channel to conn.Writer: %v", err) + } + return +} + +func (f *Forwarder) SetType(tunnelType types.TunnelType) { + f.TunnelType = tunnelType +} + +func (f *Forwarder) GetTunnelType() types.TunnelType { + return f.TunnelType +} + +func (f *Forwarder) GetForwardedPort() uint16 { + return f.ForwardedPort +} + +func (f *Forwarder) SetForwardedPort(port uint16) { + f.ForwardedPort = port +} + +func (f *Forwarder) SetListener(listener net.Listener) { + f.Listener = listener +} + +func (f *Forwarder) GetListener() net.Listener { + return f.Listener +} + +func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { + _, err := dst.Write(types.BadGatewayResponse) + if err != nil { + log.Printf("failed to write Bad Gateway response: %v", err) + return + } +} + +func (f *Forwarder) Close() error { + if f.GetTunnelType() != types.HTTP { + return f.Listener.Close() + } + return nil +} + +func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { + var buf bytes.Buffer + + host, originPort := parseAddr(origin.String()) + + writeSSHString(&buf, "localhost") + err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort())) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + + writeSSHString(&buf, host) + err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return nil + } + + return buf.Bytes() +} + +func parseAddr(addr string) (string, uint16) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) + return "0.0.0.0", uint16(0) + } + port, _ := strconv.Atoi(portStr) + return host, uint16(port) +} + +func writeSSHString(buffer *bytes.Buffer, str string) { + err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) + if err != nil { + log.Printf("Failed to write string to buffer: %v", err) + return + } + buffer.WriteString(str) +} diff --git a/session/handler.go b/session/handler.go index 5c63338..e2a77f7 100644 --- a/session/handler.go +++ b/session/handler.go @@ -3,98 +3,24 @@ package session import ( "bytes" "encoding/binary" - "errors" "fmt" - "io" "log" "net" - "strconv" - "sync" - "time" portUtil "tunnel_pls/internal/port" + "tunnel_pls/types" "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) -type Status string - -var forbiddenSlug = []string{ - "ping", -} - -type UserConnection struct { - Reader io.Reader - Writer net.Conn -} - -var ( - clientsMutex sync.RWMutex - Clients = make(map[string]*SSHSession) -) - -func registerClient(slug string, session *SSHSession) bool { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - if _, exists := Clients[slug]; exists { - return false - } - - Clients[slug] = session - return true -} - -func unregisterClient(slug string) { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - delete(Clients, slug) -} - -func (s *SSHSession) Close() error { - if s.Forwarder.Listener != nil { - err := s.Forwarder.Listener.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - return err - } - } - - if s.channel != nil { - err := s.channel.Close() - if err != nil && !errors.Is(err, io.EOF) { - return err - } - } - - if s.Conn != nil { - err := s.Conn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - return err - } - } - - slug := s.Forwarder.getSlug() - if slug != "" { - unregisterClient(slug) - } - - if s.Forwarder.TunnelType == TCP && s.Forwarder.Listener != nil { - err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) - if err != nil { - return err - } - } - - return nil -} +var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { for req := range GlobalRequest { switch req.Type { case "tcpip-forward": - s.handleTCPIPForward(req) + s.HandleTCPIPForward(req) return case "shell", "pty-req", "window-change": err := req.Reply(true, nil) @@ -113,7 +39,7 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { } } -func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { +func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { log.Println("Port forwarding request detected") reader := bytes.NewReader(req.Payload) @@ -126,7 +52,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -142,7 +68,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -156,7 +82,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -172,7 +98,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -180,11 +106,11 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { } s.Interaction.SendMessage("\033[H\033[2J") - s.Lifecycle.Status = RUNNING + s.Lifecycle.SetStatus(types.RUNNING) go s.Interaction.HandleUserInput() if portToBind == 80 || portToBind == 443 { - s.handleHTTPForward(req, portToBind) + s.HandleHTTPForward(req, portToBind) return } else { if portToBind == 0 { @@ -197,7 +123,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -210,7 +136,7 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -222,29 +148,12 @@ func (s *SSHSession) handleTCPIPForward(req *ssh.Request) { return } } - s.handleTCPForward(req, addr, portToBind) + s.HandleTCPForward(req, addr, portToBind) } -var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} - -func isBlockedPort(port uint16) bool { - if port == 80 || port == 443 { - return false - } - if port < 1024 && port != 0 { - return true - } - for _, p := range blockedReservedPorts { - if p == port { - return true - } - } - return false -} - -func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { - s.Forwarder.TunnelType = HTTP - s.Forwarder.ForwardedPort = portToBind +func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { + s.Forwarder.SetType(types.HTTP) + s.Forwarder.SetForwardedPort(portToBind) slug := generateUniqueSlug() if slug == "" { @@ -256,7 +165,7 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { return } - s.Forwarder.setSlug(slug) + s.SlugManager.Set(slug) registerClient(slug, s) buf := new(bytes.Buffer) @@ -282,8 +191,8 @@ func (s *SSHSession) handleHTTPForward(req *ssh.Request, portToBind uint16) { } } -func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - s.Forwarder.TunnelType = TCP +func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { + s.Forwarder.SetType(types.TCP) log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) @@ -294,18 +203,18 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind log.Println("Failed to reply to request:", err) return } - err = s.Close() + err = s.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } return } - s.Forwarder.Listener = listener - s.Forwarder.ForwardedPort = portToBind + s.Forwarder.SetListener(listener) + s.Forwarder.SetForwardedPort(portToBind) s.Interaction.ShowWelcomeMessage() - s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.TunnelType, utils.Getenv("domain"), s.Forwarder.ForwardedPort)) + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) - go s.acceptTCPConnections() + go s.Forwarder.AcceptTCPConnections() buf := new(bytes.Buffer) err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) @@ -321,37 +230,6 @@ func (s *SSHSession) handleTCPForward(req *ssh.Request, addr string, portToBind } } -func (s *SSHSession) acceptTCPConnections() { - for { - conn, err := s.Forwarder.Listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - log.Printf("Error accepting connection: %v", err) - continue - } - originHost, originPort := ParseAddr(conn.RemoteAddr().String()) - payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort()) - channel, reqs, err := s.Conn.OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - return - } - - go func() { - for req := range reqs { - err := req.Reply(false, nil) - if err != nil { - log.Printf("Failed to reply to request: %v", err) - return - } - } - }() - go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr()) - } -} - func generateUniqueSlug() string { maxAttempts := 5 @@ -371,95 +249,6 @@ func generateUniqueSlug() string { return "" } -func (s *SSHSession) waitForRunningStatus() { - timeout := time.After(3 * time.Second) - ticker := time.NewTicker(150 * time.Millisecond) - defer ticker.Stop() - frames := []string{"-", "\\", "|", "/"} - i := 0 - for { - select { - case <-ticker.C: - s.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) - i = (i + 1) % len(frames) - if s.Lifecycle.Status == RUNNING { - s.Interaction.SendMessage("\r\033[K") - return - } - case <-timeout: - s.Interaction.SendMessage("\r\033[K") - s.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") - err := s.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } - log.Println("Timeout waiting for session to start running") - return - } - } -} - -func isForbiddenSlug(slug string) bool { - for _, s := range forbiddenSlug { - if slug == s { - return true - } - } - return false -} - -func isValidSlug(slug string) bool { - if len(slug) < 3 || len(slug) > 20 { - return false - } - - if slug[0] == '-' || slug[len(slug)-1] == '-' { - return false - } - - for _, c := range slug { - if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') { - return false - } - } - - return true -} - -func waitForKeyPress(connection ssh.Channel) { - keyBuf := make([]byte, 1) - for { - _, err := connection.Read(keyBuf) - if err == nil { - break - } - } -} - -func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { - defer func(src ssh.Channel) { - err := src.Close() - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing connection: %v", err) - } - }(src) - log.Printf("Handling new forwarded connection from %s", remoteAddr) - - go func() { - _, err := io.Copy(src, dst) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying from conn.Reader to channel: %v", err) - } - }() - - _, err := io.Copy(dst, src) - - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error copying from channel to conn.Writer: %v", err) - } - return -} - func readSSHString(reader *bytes.Reader) (string, error) { var length uint32 if err := binary.Read(reader, binary.BigEndian, &length); err != nil { @@ -472,40 +261,17 @@ func readSSHString(reader *bytes.Reader) (string, error) { return string(strBytes), nil } -func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { - var buf bytes.Buffer - - writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(port)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil +func isBlockedPort(port uint16) bool { + if port == 80 || port == 443 { + return false } - writeSSHString(&buf, host) - err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil + if port < 1024 && port != 0 { + return true } - - return buf.Bytes() -} - -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return - } - buffer.WriteString(str) -} - -func ParseAddr(addr string) (string, uint32) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint32(0) - } - port, _ := strconv.Atoi(portStr) - return host, uint32(port) + for _, p := range blockedReservedPorts { + if p == port { + return true + } + } + return false } diff --git a/session/interaction.go b/session/interaction/interaction.go similarity index 75% rename from session/interaction.go rename to session/interaction/interaction.go index cfa1ce1..0c998c4 100644 --- a/session/interaction.go +++ b/session/interaction/interaction.go @@ -1,4 +1,4 @@ -package session +package interaction import ( "bytes" @@ -7,35 +7,60 @@ import ( "log" "strings" "time" + "tunnel_pls/session/slug" + "tunnel_pls/types" "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) -type InteractionController interface { +var forbiddenSlug = []string{ + "ping", +} + +type Lifecycle interface { + Close() error +} + +type Controller interface { SendMessage(message string) HandleUserInput() - HandleCommand(conn ssh.Channel, command string, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugEditMode(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, char byte, buf *bytes.Buffer) - HandleSlugSave(conn ssh.Channel, inSlugEditMode *bool, editSlug *string, buf *bytes.Buffer) - HandleSlugCancel(conn ssh.Channel, inSlugEditMode *bool, buf *bytes.Buffer) + HandleCommand(command string) + HandleSlugEditMode(connection ssh.Channel, char byte) + HandleSlugSave(conn ssh.Channel) + HandleSlugCancel(connection ssh.Channel) HandleSlugUpdateError() ShowWelcomeMessage() DisplaySlugEditor() + SetChannel(channel ssh.Channel) + SetLifecycle(lifecycle Lifecycle) + SetSlugModificator(func(oldSlug, newSlug string) bool) +} + +type Forwarder interface { + Close() error + GetTunnelType() types.TunnelType + GetForwardedPort() uint16 } type Interaction struct { - CommandBuffer *bytes.Buffer - EditMode bool - EditSlug string - channel ssh.Channel + InputLength int + CommandBuffer *bytes.Buffer + EditMode bool + EditSlug string + channel ssh.Channel + SlugManager slug.Manager + Forwarder Forwarder + Lifecycle Lifecycle + updateClientSlug func(oldSlug, newSlug string) bool +} - getSlug func() string - setSlug func(string) +func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { + i.Lifecycle = lifecycle +} - session SessionCloser - - forwarder ForwarderInfo +func (i *Interaction) SetChannel(channel ssh.Channel) { + i.channel = channel } func (i *Interaction) SendMessage(message string) { @@ -49,7 +74,6 @@ func (i *Interaction) SendMessage(message string) { } func (i *Interaction) HandleUserInput() { - var commandBuffer bytes.Buffer buf := make([]byte, 1) i.EditMode = false @@ -66,42 +90,47 @@ func (i *Interaction) HandleUserInput() { char := buf[0] if i.EditMode { - i.HandleSlugEditMode(i.channel, char, &commandBuffer) + i.HandleSlugEditMode(i.channel, char) continue } i.SendMessage(string(buf[:n])) if char == 8 || char == 127 { - if commandBuffer.Len() > 0 { - commandBuffer.Truncate(commandBuffer.Len() - 1) + if i.InputLength > 0 { i.SendMessage("\b \b") } + if i.CommandBuffer.Len() > 0 { + i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) + } continue } + i.InputLength += n + if char == '/' { - commandBuffer.Reset() - commandBuffer.WriteByte(char) + i.CommandBuffer.Reset() + i.CommandBuffer.WriteByte(char) continue } - if commandBuffer.Len() > 0 { + if i.CommandBuffer.Len() > 0 { if char == 13 { - i.HandleCommand(commandBuffer.String(), &commandBuffer) + i.SendMessage("\033[K") + i.HandleCommand(i.CommandBuffer.String()) continue } - commandBuffer.WriteByte(char) + i.CommandBuffer.WriteByte(char) } } } } -func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) { +func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte) { if char == 13 { i.HandleSlugSave(connection) } else if char == 27 { - i.HandleSlugCancel(connection, commandBuffer) + i.HandleSlugCancel(connection) } else if char == 8 || char == 127 { if len(i.EditSlug) > 0 { i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1] @@ -142,10 +171,10 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { return } if isValid { - oldSlug := i.getSlug() + oldSlug := i.SlugManager.Get() newSlug := i.EditSlug - if !updateClientSlug(oldSlug, newSlug) { + if !i.updateClientSlug(oldSlug, newSlug) { i.HandleSlugUpdateError() return } @@ -223,7 +252,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { if utils.Getenv("tls_enabled") == "true" { protocol = "https" } - _, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.getSlug(), domain))) + _, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))) if err != nil { log.Printf("failed to write to channel: %v", err) return @@ -233,7 +262,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { i.CommandBuffer.Reset() } -func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) { +func (i *Interaction) HandleSlugCancel(connection ssh.Channel) { i.EditMode = false _, err := connection.Write([]byte("\033[H\033[2J")) if err != nil { @@ -260,7 +289,7 @@ func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *by } i.ShowWelcomeMessage() - commandBuffer.Reset() + i.CommandBuffer.Reset() } func (i *Interaction) HandleSlugUpdateError() { @@ -271,44 +300,44 @@ func (i *Interaction) HandleSlugUpdateError() { i.SendMessage(fmt.Sprintf("Disconnecting in %d...\r\n", iter)) time.Sleep(1 * time.Second) } - err := i.session.Close() + err := i.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) return } } -func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) { +func (i *Interaction) HandleCommand(command string) { switch command { case "/bye": i.SendMessage("\r\nClosing connection...") - err := i.session.Close() + err := i.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) return } return case "/help": - i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug") + i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug\r\n") case "/clear": i.SendMessage("\033[H\033[2J") i.ShowWelcomeMessage() domain := utils.Getenv("domain") - if i.forwarder.GetTunnelType() == HTTP { + if i.Forwarder.GetTunnelType() == types.HTTP { protocol := "http" if utils.Getenv("tls_enabled") == "true" { protocol = "https" } - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.getSlug(), domain)) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) } else { - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", i.forwarder.GetTunnelType(), domain, i.forwarder.GetForwardedPort())) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort())) } case "/slug": - if i.forwarder.GetTunnelType() != HTTP { - i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.forwarder.GetTunnelType()))) + if i.Forwarder.GetTunnelType() != types.HTTP { + i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType())) } else { i.EditMode = true - i.EditSlug = i.getSlug() + i.EditSlug = i.SlugManager.Get() i.SendMessage("\033[H\033[2J") i.DisplaySlugEditor() i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain")) @@ -317,7 +346,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) i.SendMessage("Unknown command") } - commandBuffer.Reset() + i.CommandBuffer.Reset() } func (i *Interaction) ShowWelcomeMessage() { @@ -347,7 +376,7 @@ func (i *Interaction) ShowWelcomeMessage() { func (i *Interaction) DisplaySlugEditor() { domain := utils.Getenv("domain") - fullDomain := i.getSlug() + "." + domain + fullDomain := i.SlugManager.Get() + "." + domain const paddingRight = 4 @@ -383,23 +412,8 @@ func (i *Interaction) DisplaySlugEditor() { i.SendMessage("\r\n\r\n") } -func updateClientSlug(oldSlug, newSlug string) bool { - clientsMutex.Lock() - defer clientsMutex.Unlock() - - if _, exists := Clients[newSlug]; exists && newSlug != oldSlug { - return false - } - - client, ok := Clients[oldSlug] - if !ok { - return false - } - - delete(Clients, oldSlug) - client.Forwarder.setSlug(newSlug) - Clients[newSlug] = client - return true +func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) bool) { + i.updateClientSlug = modificator } func centerText(text string, width int) string { @@ -409,3 +423,40 @@ func centerText(text string, width int) string { } return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding) } + +func isValidSlug(slug string) bool { + if len(slug) < 3 || len(slug) > 20 { + return false + } + + if slug[0] == '-' || slug[len(slug)-1] == '-' { + return false + } + + for _, c := range slug { + if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') { + return false + } + } + + return true +} + +func waitForKeyPress(connection ssh.Channel) { + keyBuf := make([]byte, 1) + for { + _, err := connection.Read(keyBuf) + if err == nil { + break + } + } +} + +func isForbiddenSlug(slug string) bool { + for _, s := range forbiddenSlug { + if slug == s { + return true + } + } + return false +} diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go new file mode 100644 index 0000000..29b02ed --- /dev/null +++ b/session/lifecycle/lifecycle.go @@ -0,0 +1,126 @@ +package lifecycle + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "time" + portUtil "tunnel_pls/internal/port" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" +) + +type Interaction interface { + SendMessage(string) +} + +type Forwarder interface { + Close() error + GetTunnelType() types.TunnelType + GetForwardedPort() uint16 +} + +type Lifecycle struct { + Status types.Status + Conn ssh.Conn + Channel ssh.Channel + + Interaction Interaction + Forwarder Forwarder + SlugManager slug.Manager + unregisterClient func(slug string) +} + +func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) { + l.unregisterClient = unregisterClient +} + +type SessionLifecycle interface { + Close() error + WaitForRunningStatus() + SetStatus(status types.Status) + GetConnection() ssh.Conn + GetChannel() ssh.Channel + SetChannel(channel ssh.Channel) + SetUnregisterClient(unregisterClient func(slug string)) +} + +func (l *Lifecycle) GetChannel() ssh.Channel { + return l.Channel +} + +func (l *Lifecycle) SetChannel(channel ssh.Channel) { + l.Channel = channel +} +func (l *Lifecycle) GetConnection() ssh.Conn { + return l.Conn +} +func (l *Lifecycle) SetStatus(status types.Status) { + l.Status = status +} +func (l *Lifecycle) WaitForRunningStatus() { + timeout := time.After(3 * time.Second) + ticker := time.NewTicker(150 * time.Millisecond) + defer ticker.Stop() + frames := []string{"-", "\\", "|", "/"} + i := 0 + for { + select { + case <-ticker.C: + l.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) + i = (i + 1) % len(frames) + if l.Status == types.RUNNING { + l.Interaction.SendMessage("\r\033[K") + return + } + case <-timeout: + l.Interaction.SendMessage("\r\033[K") + l.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") + err := l.Close() + if err != nil { + log.Printf("failed to close session: %v", err) + } + log.Println("Timeout waiting for session to start running") + return + } + } +} + +func (l *Lifecycle) Close() error { + err := l.Forwarder.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + return err + } + + if l.Channel != nil { + err := l.Channel.Close() + if err != nil && !errors.Is(err, io.EOF) { + return err + } + } + + if l.Conn != nil { + err := l.Conn.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + return err + } + } + + clientSlug := l.SlugManager.Get() + if clientSlug != "" { + l.unregisterClient(clientSlug) + } + + if l.Forwarder.GetTunnelType() == types.TCP { + err := portUtil.Manager.SetPortStatus(l.Forwarder.GetForwardedPort(), false) + if err != nil { + return err + } + } + + return nil +} diff --git a/session/session.go b/session/session.go index 2a38c6a..e122e38 100644 --- a/session/session.go +++ b/session/session.go @@ -4,102 +4,85 @@ import ( "bytes" "log" "sync" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" + "tunnel_pls/types" "golang.org/x/crypto/ssh" ) -const ( - INITIALIZING Status = "INITIALIZING" - RUNNING Status = "RUNNING" - SETUP Status = "SETUP" +var ( + clientsMutex sync.RWMutex + Clients = make(map[string]*SSHSession) ) -type TunnelType string - -const ( - HTTP TunnelType = "http" - TCP TunnelType = "tcp" -) - -type SessionLifecycle interface { - Close() error - WaitForRunningStatus() -} - -type SessionCloser interface { - Close() error -} - type Session interface { - SessionLifecycle - InteractionController - ForwardingController -} - -type Lifecycle struct { - Status Status + HandleGlobalRequest(ch <-chan *ssh.Request) + HandleTCPIPForward(req *ssh.Request) + HandleHTTPForward(req *ssh.Request, port uint16) + HandleTCPForward(req *ssh.Request, addr string, port uint16) } type SSHSession struct { - Lifecycle *Lifecycle - Interaction *Interaction - Forwarder *Forwarder - - Conn *ssh.ServerConn - channel ssh.Channel - - slug string - slugMu sync.RWMutex + Lifecycle lifecycle.SessionLifecycle + Interaction interaction.Controller + Forwarder forwarder.ForwardingController + SlugManager slug.Manager } func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { - session := SSHSession{ - Lifecycle: &Lifecycle{ - Status: INITIALIZING, - }, - Interaction: &Interaction{ - CommandBuffer: new(bytes.Buffer), - EditMode: false, - EditSlug: "", - channel: nil, - getSlug: nil, - setSlug: nil, - session: nil, - forwarder: nil, - }, - Forwarder: &Forwarder{ - Listener: nil, - TunnelType: "", - ForwardedPort: 0, - getSlug: nil, - setSlug: nil, - }, - Conn: conn, - channel: nil, - slug: "", + slugManager := slug.NewManager() + forwarderManager := &forwarder.Forwarder{ + Listener: nil, + TunnelType: "", + ForwardedPort: 0, + SlugManager: slugManager, + } + interactionManager := &interaction.Interaction{ + CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)), + EditMode: false, + EditSlug: "", + SlugManager: slugManager, + Forwarder: forwarderManager, + Lifecycle: nil, + } + lifecycleManager := &lifecycle.Lifecycle{ + Status: "", + Conn: conn, + Channel: nil, + Interaction: interactionManager, + Forwarder: forwarderManager, + SlugManager: slugManager, } - session.Forwarder.getSlug = session.GetSlug - session.Forwarder.setSlug = session.SetSlug - session.Interaction.getSlug = session.GetSlug - session.Interaction.setSlug = session.SetSlug - session.Interaction.session = &session - session.Interaction.forwarder = session.Forwarder + interactionManager.SetLifecycle(lifecycleManager) + interactionManager.SetSlugModificator(updateClientSlug) + forwarderManager.SetLifecycle(lifecycleManager) + lifecycleManager.SetUnregisterClient(unregisterClient) + + session := &SSHSession{ + Lifecycle: lifecycleManager, + Interaction: interactionManager, + Forwarder: forwarderManager, + SlugManager: slugManager, + } go func() { - go session.waitForRunningStatus() + go session.Lifecycle.WaitForRunningStatus() for channel := range sshChan { ch, reqs, _ := channel.Accept() - if session.channel == nil { - session.channel = ch - session.Interaction.channel = ch - session.Lifecycle.Status = SETUP + if session.Lifecycle.GetChannel() == nil { + session.Lifecycle.SetChannel(ch) + session.Interaction.SetChannel(ch) + session.Lifecycle.SetStatus(types.SETUP) go session.HandleGlobalRequest(forwardingReq) } go session.HandleGlobalRequest(reqs) } - err := session.Close() + err := session.Lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -107,14 +90,40 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan }() } -func (s *SSHSession) GetSlug() string { - s.slugMu.RLock() - defer s.slugMu.RUnlock() - return s.slug +func updateClientSlug(oldSlug, newSlug string) bool { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + if _, exists := Clients[newSlug]; exists && newSlug != oldSlug { + return false + } + + client, ok := Clients[oldSlug] + if !ok { + return false + } + + delete(Clients, oldSlug) + client.SlugManager.Set(newSlug) + Clients[newSlug] = client + return true } -func (s *SSHSession) SetSlug(slug string) { - s.slugMu.Lock() - s.slug = slug - s.slugMu.Unlock() +func registerClient(slug string, session *SSHSession) bool { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + if _, exists := Clients[slug]; exists { + return false + } + + Clients[slug] = session + return true +} + +func unregisterClient(slug string) { + clientsMutex.Lock() + defer clientsMutex.Unlock() + + delete(Clients, slug) } diff --git a/session/slug/slug.go b/session/slug/slug.go new file mode 100644 index 0000000..4900e22 --- /dev/null +++ b/session/slug/slug.go @@ -0,0 +1,32 @@ +package slug + +import "sync" + +type Manager interface { + Get() string + Set(slug string) +} + +type manager struct { + slug string + slugMu sync.RWMutex +} + +func NewManager() Manager { + return &manager{ + slug: "", + slugMu: sync.RWMutex{}, + } +} + +func (s *manager) Get() string { + s.slugMu.RLock() + defer s.slugMu.RUnlock() + return s.slug +} + +func (s *manager) Set(slug string) { + s.slugMu.Lock() + s.slug = slug + s.slugMu.Unlock() +} diff --git a/types/types.go b/types/types.go new file mode 100644 index 0000000..f909da5 --- /dev/null +++ b/types/types.go @@ -0,0 +1,21 @@ +package types + +type Status string + +const ( + INITIALIZING Status = "INITIALIZING" + RUNNING Status = "RUNNING" + SETUP Status = "SETUP" +) + +type TunnelType string + +const ( + HTTP TunnelType = "HTTP" + TCP TunnelType = "TCP" +) + +var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + + "Content-Length: 11\r\n" + + "Content-Type: text/plain\r\n\r\n" + + "Bad Gateway")