staging #50

Merged
bagas merged 7 commits from staging into main 2025-12-29 10:17:00 +00:00
13 changed files with 338 additions and 223 deletions
Showing only changes of commit 2644b4521c - Show all commits

View File

@@ -16,7 +16,7 @@ COPY . .
RUN --mount=type=cache,target=/go/pkg/mod \ RUN --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \ CGO_ENABLED=0 GOOS=linux \
go build -trimpath \ go build -trimpath \
-ldflags="-w -s" \ -ldflags="-w -s" \
-o /app/tunnel_pls \ -o /app/tunnel_pls \

View File

@@ -9,7 +9,7 @@ import (
) )
func (s *Server) handleConnection(conn net.Conn) { func (s *Server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config) sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
if err != nil { if err != nil {
log.Printf("failed to establish SSH connection: %v", err) log.Printf("failed to establish SSH connection: %v", err)
err := conn.Close() err := conn.Close()

View File

@@ -14,21 +14,38 @@ type HeaderManager interface {
Finalize() []byte Finalize() []byte
} }
type ResponseHeaderFactory struct { type ResponseHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type RequestHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
}
type responseHeaderFactory struct {
startLine []byte startLine []byte
headers map[string]string headers map[string]string
} }
type RequestHeaderFactory struct { type requestHeaderFactory struct {
Method string method string
Path string path string
Version string version string
startLine []byte startLine []byte
headers map[string]string headers map[string]string
} }
func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) { func NewRequestHeaderFactory(br *bufio.Reader) (RequestHeaderManager, error) {
header := &RequestHeaderFactory{ header := &requestHeaderFactory{
headers: make(map[string]string), headers: make(map[string]string),
} }
@@ -44,9 +61,9 @@ func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) {
return nil, fmt.Errorf("invalid request line") return nil, fmt.Errorf("invalid request line")
} }
header.Method = parts[0] header.method = parts[0]
header.Path = parts[1] header.path = parts[1]
header.Version = parts[2] header.version = parts[2]
for { for {
line, err := br.ReadString('\n') line, err := br.ReadString('\n')
@@ -69,8 +86,8 @@ func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) {
return header, nil return header, nil
} }
func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory { func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
header := &ResponseHeaderFactory{ header := &responseHeaderFactory{
startLine: nil, startLine: nil,
headers: make(map[string]string), headers: make(map[string]string),
} }
@@ -96,19 +113,19 @@ func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory {
return header return header
} }
func (resp *ResponseHeaderFactory) Get(key string) string { func (resp *responseHeaderFactory) Get(key string) string {
return resp.headers[key] return resp.headers[key]
} }
func (resp *ResponseHeaderFactory) Set(key string, value string) { func (resp *responseHeaderFactory) Set(key string, value string) {
resp.headers[key] = value resp.headers[key] = value
} }
func (resp *ResponseHeaderFactory) Remove(key string) { func (resp *responseHeaderFactory) Remove(key string) {
delete(resp.headers, key) delete(resp.headers, key)
} }
func (resp *ResponseHeaderFactory) Finalize() []byte { func (resp *responseHeaderFactory) Finalize() []byte {
var buf bytes.Buffer var buf bytes.Buffer
buf.Write(resp.startLine) buf.Write(resp.startLine)
@@ -125,7 +142,7 @@ func (resp *ResponseHeaderFactory) Finalize() []byte {
return buf.Bytes() return buf.Bytes()
} }
func (req *RequestHeaderFactory) Get(key string) string { func (req *requestHeaderFactory) Get(key string) string {
val, ok := req.headers[key] val, ok := req.headers[key]
if !ok { if !ok {
return "" return ""
@@ -133,15 +150,27 @@ func (req *RequestHeaderFactory) Get(key string) string {
return val return val
} }
func (req *RequestHeaderFactory) Set(key string, value string) { func (req *requestHeaderFactory) Set(key string, value string) {
req.headers[key] = value req.headers[key] = value
} }
func (req *RequestHeaderFactory) Remove(key string) { func (req *requestHeaderFactory) Remove(key string) {
delete(req.headers, key) delete(req.headers, key)
} }
func (req *RequestHeaderFactory) Finalize() []byte { func (req *requestHeaderFactory) GetMethod() string {
return req.method
}
func (req *requestHeaderFactory) GetPath() string {
return req.path
}
func (req *requestHeaderFactory) GetVersion() string {
return req.version
}
func (req *requestHeaderFactory) Finalize() []byte {
var buf bytes.Buffer var buf bytes.Buffer
buf.Write(req.startLine) buf.Write(req.startLine)

View File

@@ -20,25 +20,63 @@ import (
type Interaction interface { type Interaction interface {
SendMessage(message string) SendMessage(message string)
} }
type CustomWriter struct {
RemoteAddr net.Addr type HTTPWriter interface {
io.Reader
io.Writer
SetInteraction(interaction Interaction)
AddInteraction(interaction Interaction)
GetRemoteAddr() net.Addr
GetWriter() io.Writer
AddResponseMiddleware(mw ResponseMiddleware)
AddRequestStartMiddleware(mw RequestMiddleware)
SetRequestHeader(header RequestHeaderManager)
GetRequestStartMiddleware() []RequestMiddleware
}
type customWriter struct {
remoteAddr net.Addr
writer io.Writer writer io.Writer
reader io.Reader reader io.Reader
headerBuf []byte headerBuf []byte
buf []byte buf []byte
respHeader *ResponseHeaderFactory respHeader ResponseHeaderManager
reqHeader *RequestHeaderFactory reqHeader RequestHeaderManager
interaction Interaction interaction Interaction
respMW []ResponseMiddleware respMW []ResponseMiddleware
reqStartMW []RequestMiddleware reqStartMW []RequestMiddleware
reqEndMW []RequestMiddleware reqEndMW []RequestMiddleware
} }
func (cw *CustomWriter) SetInteraction(interaction Interaction) { func (cw *customWriter) SetInteraction(interaction Interaction) {
cw.interaction = interaction cw.interaction = interaction
} }
func (cw *CustomWriter) Read(p []byte) (int, error) { func (cw *customWriter) GetRemoteAddr() net.Addr {
return cw.remoteAddr
}
func (cw *customWriter) GetWriter() io.Writer {
return cw.writer
}
func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) {
cw.respMW = append(cw.respMW, mw)
}
func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) {
cw.reqStartMW = append(cw.reqStartMW, mw)
}
func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) {
cw.reqHeader = header
}
func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware {
return cw.reqStartMW
}
func (cw *customWriter) Read(p []byte) (int, error) {
tmp := make([]byte, len(p)) tmp := make([]byte, len(p))
read, err := cw.reader.Read(tmp) read, err := cw.reader.Read(tmp)
if read == 0 && err != nil { if read == 0 && err != nil {
@@ -95,9 +133,9 @@ func (cw *CustomWriter) Read(p []byte) (int, error) {
return n, nil return n, nil
} }
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter { func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
return &CustomWriter{ return &customWriter{
RemoteAddr: remoteAddr, remoteAddr: remoteAddr,
writer: writer, writer: writer,
reader: reader, reader: reader,
buf: make([]byte, 0, 4096), buf: make([]byte, 0, 4096),
@@ -129,7 +167,7 @@ func isHTTPHeader(buf []byte) bool {
return true return true
} }
func (cw *CustomWriter) Write(p []byte) (int, error) { func (cw *customWriter) Write(p []byte) (int, error) {
if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" { if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" {
cw.respHeader = nil cw.respHeader = nil
} }
@@ -186,7 +224,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func (cw *CustomWriter) AddInteraction(interaction Interaction) { func (cw *customWriter) AddInteraction(interaction Interaction) {
cw.interaction = interaction cw.interaction = interaction
} }
@@ -292,13 +330,13 @@ func Handler(conn net.Conn) {
return return
} }
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
cw.SetInteraction(sshSession.Interaction) cw.SetInteraction(sshSession.GetInteraction())
forwardRequest(cw, reqhf, sshSession) forwardRequest(cw, reqhf, sshSession)
return return
} }
func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) {
payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr) payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct { type channelResult struct {
channel ssh.Channel channel ssh.Channel
@@ -308,7 +346,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -319,29 +357,28 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
case result := <-resultChan: case result := <-resultChan:
if result.err != nil { if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer) sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
channel = result.channel channel = result.channel
reqs = result.reqs reqs = result.reqs
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel") log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer) sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
fingerprintMiddleware := NewTunnelFingerprint() fingerprintMiddleware := NewTunnelFingerprint()
forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr) forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr())
cw.respMW = append(cw.respMW, fingerprintMiddleware) cw.AddResponseMiddleware(fingerprintMiddleware)
cw.reqStartMW = append(cw.reqStartMW, forwardedForMiddleware) cw.AddRequestStartMiddleware(forwardedForMiddleware)
cw.reqEndMW = nil cw.SetRequestHeader(initialRequest)
cw.reqHeader = initialRequest
for _, m := range cw.reqStartMW { for _, m := range cw.GetRequestStartMiddleware() {
if err := m.HandleRequest(cw.reqHeader); err != nil { if err := m.HandleRequest(initialRequest); err != nil {
log.Printf("Error handling request: %v", err) log.Printf("Error handling request: %v", err)
return return
} }
@@ -353,6 +390,6 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
return return
} }
sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return return
} }

View File

@@ -104,7 +104,7 @@ func HandlerTLS(conn net.Conn) {
return return
} }
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
cw.SetInteraction(sshSession.Interaction) cw.SetInteraction(sshSession.GetInteraction())
forwardRequest(cw, reqhf, sshSession) forwardRequest(cw, reqhf, sshSession)
return return
} }

View File

@@ -5,11 +5,11 @@ import (
) )
type RequestMiddleware interface { type RequestMiddleware interface {
HandleRequest(header *RequestHeaderFactory) error HandleRequest(header RequestHeaderManager) error
} }
type ResponseMiddleware interface { type ResponseMiddleware interface {
HandleResponse(header *ResponseHeaderFactory, body []byte) error HandleResponse(header ResponseHeaderManager, body []byte) error
} }
type TunnelFingerprint struct{} type TunnelFingerprint struct{}
@@ -18,16 +18,11 @@ func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{} return &TunnelFingerprint{}
} }
func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error { func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
header.Set("Server", "Tunnel Please") header.Set("Server", "Tunnel Please")
return nil return nil
} }
type RequestLogger struct {
interaction Interaction
remoteAddr net.Addr
}
type ForwardedFor struct { type ForwardedFor struct {
addr net.Addr addr net.Addr
} }
@@ -36,7 +31,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr} return &ForwardedFor{addr: addr}
} }
func (ff *ForwardedFor) HandleRequest(header *RequestHeaderFactory) error { func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
host, _, err := net.SplitHostPort(ff.addr.String()) host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil { if err != nil {
return err return err

View File

@@ -11,9 +11,21 @@ import (
) )
type Server struct { type Server struct {
Conn *net.Listener conn *net.Listener
Config *ssh.ServerConfig config *ssh.ServerConfig
HttpServer *http.Server httpServer *http.Server
}
func (s *Server) GetConn() *net.Listener {
return s.conn
}
func (s *Server) GetConfig() *ssh.ServerConfig {
return s.config
}
func (s *Server) GetHttpServer() *http.Server {
return s.httpServer
} }
func NewServer(config *ssh.ServerConfig) *Server { func NewServer(config *ssh.ServerConfig) *Server {
@@ -33,15 +45,15 @@ func NewServer(config *ssh.ServerConfig) *Server {
log.Fatalf("failed to start http server: %v", err) log.Fatalf("failed to start http server: %v", err)
} }
return &Server{ return &Server{
Conn: &listener, conn: &listener,
Config: config, config: config,
} }
} }
func (s *Server) Start() { func (s *Server) Start() {
log.Println("SSH server is starting on port 2200...") log.Println("SSH server is starting on port 2200...")
for { for {
conn, err := (*s.Conn).Accept() conn, err := (*s.conn).Accept()
if err != nil { if err != nil {
log.Printf("failed to accept connection: %v", err) log.Printf("failed to accept connection: %v", err)
continue continue

View File

@@ -16,7 +16,16 @@ import (
"github.com/libdns/cloudflare" "github.com/libdns/cloudflare"
) )
type TLSManager struct { type TLSManager interface {
userCertsExistAndValid() bool
loadUserCerts() error
startCertWatcher()
initCertMagic() error
getTLSConfig() *tls.Config
getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
type tlsManager struct {
domain string domain string
certPath string certPath string
keyPath string keyPath string
@@ -30,7 +39,7 @@ type TLSManager struct {
useCertMagic bool useCertMagic bool
} }
var tlsManager *TLSManager var globalTLSManager TLSManager
var tlsManagerOnce sync.Once var tlsManagerOnce sync.Once
func NewTLSConfig(domain string) (*tls.Config, error) { func NewTLSConfig(domain string) (*tls.Config, error) {
@@ -41,7 +50,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
keyPath := "certs/tls/privkey.pem" keyPath := "certs/tls/privkey.pem"
storagePath := "certs/tls/certmagic" storagePath := "certs/tls/certmagic"
tm := &TLSManager{ tm := &tlsManager{
domain: domain, domain: domain,
certPath: certPath, certPath: certPath,
keyPath: keyPath, keyPath: keyPath,
@@ -72,14 +81,14 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
tm.useCertMagic = true tm.useCertMagic = true
} }
tlsManager = tm globalTLSManager = tm
}) })
if initErr != nil { if initErr != nil {
return nil, initErr return nil, initErr
} }
return tlsManager.getTLSConfig(), nil return globalTLSManager.getTLSConfig(), nil
} }
func isACMEConfigComplete() bool { func isACMEConfigComplete() bool {
@@ -87,7 +96,7 @@ func isACMEConfigComplete() bool {
return cfAPIToken != "" return cfAPIToken != ""
} }
func (tm *TLSManager) userCertsExistAndValid() bool { func (tm *tlsManager) userCertsExistAndValid() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath) log.Printf("Certificate file not found: %s", tm.certPath)
return false return false
@@ -158,7 +167,7 @@ func ValidateCertDomains(certPath, domain string) bool {
return hasBase && hasWildcard return hasBase && hasWildcard
} }
func (tm *TLSManager) loadUserCerts() error { func (tm *tlsManager) loadUserCerts() error {
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath) cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
if err != nil { if err != nil {
return err return err
@@ -172,7 +181,7 @@ func (tm *TLSManager) loadUserCerts() error {
return nil return nil
} }
func (tm *TLSManager) startCertWatcher() { func (tm *tlsManager) startCertWatcher() {
go func() { go func() {
var lastCertMod, lastKeyMod time.Time var lastCertMod, lastKeyMod time.Time
@@ -227,7 +236,7 @@ func (tm *TLSManager) startCertWatcher() {
}() }()
} }
func (tm *TLSManager) initCertMagic() error { func (tm *tlsManager) initCertMagic() error {
if err := os.MkdirAll(tm.storagePath, 0700); err != nil { if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
return fmt.Errorf("failed to create cert storage directory: %w", err) return fmt.Errorf("failed to create cert storage directory: %w", err)
} }
@@ -289,14 +298,14 @@ func (tm *TLSManager) initCertMagic() error {
return nil return nil
} }
func (tm *TLSManager) getTLSConfig() *tls.Config { func (tm *tlsManager) getTLSConfig() *tls.Config {
return &tls.Config{ return &tls.Config{
GetCertificate: tm.getCertificate, GetCertificate: tm.getCertificate,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
} }
func (tm *TLSManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if tm.useCertMagic { if tm.useCertMagic {
return tm.magic.GetCertificate(hello) return tm.magic.GetCertificate(hello)
} }

View File

@@ -31,11 +31,21 @@ func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
} }
type Forwarder struct { type Forwarder struct {
Listener net.Listener listener net.Listener
TunnelType types.TunnelType tunnelType types.TunnelType
ForwardedPort uint16 forwardedPort uint16
SlugManager slug.Manager slugManager slug.Manager
Lifecycle Lifecycle lifecycle Lifecycle
}
func NewForwarder(slugManager slug.Manager) *Forwarder {
return &Forwarder{
listener: nil,
tunnelType: "",
forwardedPort: 0,
slugManager: slugManager,
lifecycle: nil,
}
} }
type Lifecycle interface { type Lifecycle interface {
@@ -58,7 +68,7 @@ type ForwardingController interface {
} }
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
f.Lifecycle = lifecycle f.lifecycle = lifecycle
} }
func (f *Forwarder) AcceptTCPConnections() { func (f *Forwarder) AcceptTCPConnections() {
@@ -90,7 +100,7 @@ func (f *Forwarder) AcceptTCPConnections() {
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -164,27 +174,27 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
} }
func (f *Forwarder) SetType(tunnelType types.TunnelType) { func (f *Forwarder) SetType(tunnelType types.TunnelType) {
f.TunnelType = tunnelType f.tunnelType = tunnelType
} }
func (f *Forwarder) GetTunnelType() types.TunnelType { func (f *Forwarder) GetTunnelType() types.TunnelType {
return f.TunnelType return f.tunnelType
} }
func (f *Forwarder) GetForwardedPort() uint16 { func (f *Forwarder) GetForwardedPort() uint16 {
return f.ForwardedPort return f.forwardedPort
} }
func (f *Forwarder) SetForwardedPort(port uint16) { func (f *Forwarder) SetForwardedPort(port uint16) {
f.ForwardedPort = port f.forwardedPort = port
} }
func (f *Forwarder) SetListener(listener net.Listener) { func (f *Forwarder) SetListener(listener net.Listener) {
f.Listener = listener f.listener = listener
} }
func (f *Forwarder) GetListener() net.Listener { func (f *Forwarder) GetListener() net.Listener {
return f.Listener return f.listener
} }
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
@@ -197,7 +207,7 @@ func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
func (f *Forwarder) Close() error { func (f *Forwarder) Close() error {
if f.GetListener() != nil { if f.GetListener() != nil {
return f.Listener.Close() return f.listener.Close()
} }
return nil return nil
} }

View File

@@ -49,7 +49,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -59,13 +59,13 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
var rawPortToBind uint32 var rawPortToBind uint32
if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
log.Println("Failed to read port from payload:", err) log.Println("Failed to read port from payload:", err)
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind))
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -73,13 +73,13 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
} }
if rawPortToBind > 65535 { if rawPortToBind > 65535 {
s.Interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind))
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -89,13 +89,13 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
portToBind := uint16(rawPortToBind) portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) { if isBlockedPort(portToBind) {
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind))
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -110,26 +110,26 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
unassign, success := portUtil.Default.GetUnassignedPort() unassign, success := portUtil.Default.GetUnassignedPort()
portToBind = unassign portToBind = unassign
if !success { if !success {
s.Interaction.SendMessage("No available port\r\n") s.interaction.SendMessage("No available port\r\n")
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return return
} }
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse { } else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind))
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -193,21 +193,21 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
return return
} }
s.Forwarder.SetType(types.HTTP) s.forwarder.SetType(types.HTTP)
s.Forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.SlugManager.Set(slug) s.slugManager.Set(slug)
s.Interaction.SendMessage("\033[H\033[2J") s.interaction.SendMessage("\033[H\033[2J")
s.Interaction.ShowWelcomeMessage() s.interaction.ShowWelcomeMessage()
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain))
s.Lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
s.Interaction.HandleUserInput() s.interaction.HandleUserInput()
} }
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
log.Printf("Requested forwarding on %s:%d", addr, portToBind) log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil { if err != nil {
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr) log.Printf("Failed to reset port status: %v", setErr)
} }
@@ -216,7 +216,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -253,15 +253,15 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
return return
} }
s.Forwarder.SetType(types.TCP) s.forwarder.SetType(types.TCP)
s.Forwarder.SetListener(listener) s.forwarder.SetListener(listener)
s.Forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.Interaction.SendMessage("\033[H\033[2J") s.interaction.SendMessage("\033[H\033[2J")
s.Interaction.ShowWelcomeMessage() s.interaction.ShowWelcomeMessage()
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("DOMAIN", "localhost"), s.Forwarder.GetForwardedPort())) s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("DOMAIN", "localhost"), s.forwarder.GetForwardedPort()))
s.Lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
go s.Forwarder.AcceptTCPConnections() go s.forwarder.AcceptTCPConnections()
s.Interaction.HandleUserInput() s.interaction.HandleUserInput()
} }
func generateUniqueSlug() string { func generateUniqueSlug() string {

View File

@@ -42,21 +42,37 @@ type Forwarder interface {
} }
type Interaction struct { type Interaction struct {
InputLength int inputLength int
CommandBuffer *bytes.Buffer commandBuffer *bytes.Buffer
InteractiveMode bool interactiveMode bool
InteractionType types.InteractionType interactionType types.InteractionType
EditSlug string editSlug string
channel ssh.Channel channel ssh.Channel
SlugManager slug.Manager slugManager slug.Manager
Forwarder Forwarder forwarder Forwarder
Lifecycle Lifecycle lifecycle Lifecycle
pendingExit bool pendingExit bool
updateClientSlug func(oldSlug, newSlug string) bool updateClientSlug func(oldSlug, newSlug string) bool
} }
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
return &Interaction{
inputLength: 0,
commandBuffer: bytes.NewBuffer(make([]byte, 0, 20)),
interactiveMode: false,
interactionType: "",
editSlug: "",
channel: nil,
slugManager: slugManager,
forwarder: forwarder,
lifecycle: nil,
pendingExit: false,
updateClientSlug: nil,
}
}
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
i.Lifecycle = lifecycle i.lifecycle = lifecycle
} }
func (i *Interaction) SetChannel(channel ssh.Channel) { func (i *Interaction) SetChannel(channel ssh.Channel) {
@@ -77,7 +93,7 @@ func (i *Interaction) SendMessage(message string) {
func (i *Interaction) HandleUserInput() { func (i *Interaction) HandleUserInput() {
buf := make([]byte, 1) buf := make([]byte, 1)
i.InteractiveMode = false i.interactiveMode = false
for { for {
n, err := i.channel.Read(buf) n, err := i.channel.Read(buf)
@@ -99,7 +115,7 @@ func (i *Interaction) handleReadError(err error) {
} }
func (i *Interaction) processCharacter(char byte) { func (i *Interaction) processCharacter(char byte) {
if i.InteractiveMode { if i.interactiveMode {
i.handleInteractiveMode(char) i.handleInteractiveMode(char)
return return
} }
@@ -113,7 +129,7 @@ func (i *Interaction) processCharacter(char byte) {
} }
func (i *Interaction) handleInteractiveMode(char byte) { func (i *Interaction) handleInteractiveMode(char byte) {
switch i.InteractionType { switch i.interactionType {
case types.Slug: case types.Slug:
i.HandleSlugEditMode(char) i.HandleSlugEditMode(char)
} }
@@ -123,7 +139,7 @@ func (i *Interaction) handleExitSequence(char byte) bool {
if char == ctrlC { if char == ctrlC {
if i.pendingExit { if i.pendingExit {
i.SendMessage("Closing connection...\r\n") i.SendMessage("Closing connection...\r\n")
if err := i.Lifecycle.Close(); err != nil { if err := i.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return true return true
@@ -147,37 +163,37 @@ func (i *Interaction) handleNonInteractiveInput(char byte) {
i.handleBackspace() i.handleBackspace()
case char == forwardSlash: case char == forwardSlash:
i.handleCommandStart() i.handleCommandStart()
case i.CommandBuffer.Len() > 0: case i.commandBuffer.Len() > 0:
i.handleCommandInput(char) i.handleCommandInput(char)
case char == enterChar: case char == enterChar:
i.SendMessage(clearLine) i.SendMessage(clearLine)
default: default:
i.InputLength++ i.inputLength++
} }
} }
func (i *Interaction) handleBackspace() { func (i *Interaction) handleBackspace() {
if i.InputLength > 0 { if i.inputLength > 0 {
i.SendMessage(backspaceSeq) i.SendMessage(backspaceSeq)
} }
if i.CommandBuffer.Len() > 0 { if i.commandBuffer.Len() > 0 {
i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) i.commandBuffer.Truncate(i.commandBuffer.Len() - 1)
} }
} }
func (i *Interaction) handleCommandStart() { func (i *Interaction) handleCommandStart() {
i.CommandBuffer.Reset() i.commandBuffer.Reset()
i.CommandBuffer.WriteByte(forwardSlash) i.commandBuffer.WriteByte(forwardSlash)
} }
func (i *Interaction) handleCommandInput(char byte) { func (i *Interaction) handleCommandInput(char byte) {
if char == enterChar { if char == enterChar {
i.SendMessage(clearLine) i.SendMessage(clearLine)
i.HandleCommand(i.CommandBuffer.String()) i.HandleCommand(i.commandBuffer.String())
return return
} }
i.CommandBuffer.WriteByte(char) i.commandBuffer.WriteByte(char)
i.InputLength++ i.inputLength++
} }
func (i *Interaction) HandleSlugEditMode(char byte) { func (i *Interaction) HandleSlugEditMode(char byte) {
@@ -194,15 +210,15 @@ func (i *Interaction) HandleSlugEditMode(char byte) {
} }
func (i *Interaction) handleSlugBackspace() { func (i *Interaction) handleSlugBackspace() {
if len(i.EditSlug) > 0 { if len(i.editSlug) > 0 {
i.EditSlug = i.EditSlug[:len(i.EditSlug)-1] i.editSlug = i.editSlug[:len(i.editSlug)-1]
i.refreshSlugDisplay() i.refreshSlugDisplay()
} }
} }
func (i *Interaction) appendToSlug(char byte) { func (i *Interaction) appendToSlug(char byte) {
if isValidSlugChar(char) { if len(i.editSlug) < maxSlugLength {
i.EditSlug += string(char) i.editSlug += string(char)
i.refreshSlugDisplay() i.refreshSlugDisplay()
} }
} }
@@ -210,16 +226,16 @@ func (i *Interaction) appendToSlug(char byte) {
func (i *Interaction) refreshSlugDisplay() { func (i *Interaction) refreshSlugDisplay() {
domain := utils.Getenv("DOMAIN", "localhost") domain := utils.Getenv("DOMAIN", "localhost")
i.SendMessage(clearToLineEnd) i.SendMessage(clearToLineEnd)
i.SendMessage("➤ " + i.EditSlug + "." + domain) i.SendMessage("➤ " + i.editSlug + "." + domain)
} }
func (i *Interaction) HandleSlugSave() { func (i *Interaction) HandleSlugSave() {
i.SendMessage(clearScreen) i.SendMessage(clearScreen)
switch { switch {
case isForbiddenSlug(i.EditSlug): case isForbiddenSlug(i.editSlug):
i.showForbiddenSlugMessage() i.showForbiddenSlugMessage()
case !isValidSlug(i.EditSlug): case !isValidSlug(i.editSlug):
i.showInvalidSlugMessage() i.showInvalidSlugMessage()
default: default:
i.updateSlug() i.updateSlug()
@@ -230,8 +246,8 @@ func (i *Interaction) HandleSlugSave() {
} }
func (i *Interaction) updateSlug() { func (i *Interaction) updateSlug() {
oldSlug := i.SlugManager.Get() oldSlug := i.slugManager.Get()
newSlug := i.EditSlug newSlug := i.editSlug
if !i.updateClientSlug(oldSlug, newSlug) { if !i.updateClientSlug(oldSlug, newSlug) {
i.HandleSlugUpdateError() i.HandleSlugUpdateError()
@@ -262,8 +278,8 @@ func (i *Interaction) returnToMainScreen() {
i.SendMessage(clearScreen) i.SendMessage(clearScreen)
i.ShowWelcomeMessage() i.ShowWelcomeMessage()
i.ShowForwardingMessage() i.ShowForwardingMessage()
i.InteractiveMode = false i.interactiveMode = false
i.CommandBuffer.Reset() i.commandBuffer.Reset()
} }
func (i *Interaction) HandleSlugCancel() { func (i *Interaction) HandleSlugCancel() {
@@ -271,8 +287,8 @@ func (i *Interaction) HandleSlugCancel() {
i.SendMessage("\r\n\r\n⚠ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n") i.SendMessage("\r\n\r\n⚠ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n")
i.SendMessage("Press any key to continue...\r\n") i.SendMessage("Press any key to continue...\r\n")
i.InteractiveMode = false i.interactiveMode = false
i.InteractionType = "" i.interactionType = ""
i.WaitForKeyPress() i.WaitForKeyPress()
i.SendMessage(clearScreen) i.SendMessage(clearScreen)
@@ -289,7 +305,7 @@ func (i *Interaction) HandleSlugUpdateError() {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
if err := i.Lifecycle.Close(); err != nil { if err := i.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
} }
@@ -308,12 +324,12 @@ func (i *Interaction) HandleCommand(command string) {
i.SendMessage("Unknown command\r\n") i.SendMessage("Unknown command\r\n")
} }
i.CommandBuffer.Reset() i.commandBuffer.Reset()
} }
func (i *Interaction) handleByeCommand() { func (i *Interaction) handleByeCommand() {
i.SendMessage("Closing connection...\r\n") i.SendMessage("Closing connection...\r\n")
if err := i.Lifecycle.Close(); err != nil { if err := i.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
} }
@@ -329,32 +345,32 @@ func (i *Interaction) handleClearCommand() {
} }
func (i *Interaction) handleSlugCommand() { func (i *Interaction) handleSlugCommand() {
if i.Forwarder.GetTunnelType() != types.HTTP { if i.forwarder.GetTunnelType() != types.HTTP {
i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains\r\n", i.Forwarder.GetTunnelType())) i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains\r\n", i.forwarder.GetTunnelType()))
return return
} }
i.InteractiveMode = true i.interactiveMode = true
i.InteractionType = types.Slug i.interactionType = types.Slug
i.EditSlug = i.SlugManager.Get() i.editSlug = i.slugManager.Get()
i.SendMessage(clearScreen) i.SendMessage(clearScreen)
i.DisplaySlugEditor() i.DisplaySlugEditor()
domain := utils.Getenv("DOMAIN", "localhost") domain := utils.Getenv("DOMAIN", "localhost")
i.SendMessage("➤ " + i.EditSlug + "." + domain) i.SendMessage("➤ " + i.editSlug + "." + domain)
} }
func (i *Interaction) ShowForwardingMessage() { func (i *Interaction) ShowForwardingMessage() {
domain := utils.Getenv("DOMAIN", "localhost") domain := utils.Getenv("DOMAIN", "localhost")
if i.Forwarder.GetTunnelType() == types.HTTP { if i.forwarder.GetTunnelType() == types.HTTP {
protocol := "http" protocol := "http"
if utils.Getenv("TLS_ENABLED", "false") == "true" { if utils.Getenv("TLS_ENABLED", "false") == "true" {
protocol = "https" protocol = "https"
} }
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.slugManager.Get(), domain))
} else { } else {
i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort())) i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.forwarder.GetForwardedPort()))
} }
} }
@@ -385,7 +401,7 @@ func (i *Interaction) ShowWelcomeMessage() {
func (i *Interaction) DisplaySlugEditor() { func (i *Interaction) DisplaySlugEditor() {
domain := utils.Getenv("DOMAIN", "localhost") domain := utils.Getenv("DOMAIN", "localhost")
fullDomain := i.SlugManager.Get() + "." + domain fullDomain := i.slugManager.Get() + "." + domain
contentLine := " ║ Current: " + fullDomain contentLine := " ║ Current: " + fullDomain
boxWidth := calculateBoxWidth(contentLine) boxWidth := calculateBoxWidth(contentLine)

View File

@@ -22,16 +22,27 @@ type Forwarder interface {
} }
type Lifecycle struct { type Lifecycle struct {
Status types.Status status types.Status
Conn ssh.Conn conn ssh.Conn
Channel ssh.Channel channel ssh.Channel
interaction Interaction
Interaction Interaction forwarder Forwarder
Forwarder Forwarder slugManager slug.Manager
SlugManager slug.Manager
unregisterClient func(slug string) unregisterClient func(slug string)
} }
func NewLifecycle(conn ssh.Conn, interaction Interaction, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
return &Lifecycle{
status: "",
conn: conn,
channel: nil,
interaction: interaction,
forwarder: forwarder,
slugManager: slugManager,
unregisterClient: nil,
}
}
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) { func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
l.unregisterClient = unregisterClient l.unregisterClient = unregisterClient
} }
@@ -46,46 +57,46 @@ type SessionLifecycle interface {
} }
func (l *Lifecycle) GetChannel() ssh.Channel { func (l *Lifecycle) GetChannel() ssh.Channel {
return l.Channel return l.channel
} }
func (l *Lifecycle) SetChannel(channel ssh.Channel) { func (l *Lifecycle) SetChannel(channel ssh.Channel) {
l.Channel = channel l.channel = channel
} }
func (l *Lifecycle) GetConnection() ssh.Conn { func (l *Lifecycle) GetConnection() ssh.Conn {
return l.Conn return l.conn
} }
func (l *Lifecycle) SetStatus(status types.Status) { func (l *Lifecycle) SetStatus(status types.Status) {
l.Status = status l.status = status
} }
func (l *Lifecycle) Close() error { func (l *Lifecycle) Close() error {
err := l.Forwarder.Close() err := l.forwarder.Close()
if err != nil && !errors.Is(err, net.ErrClosed) { if err != nil && !errors.Is(err, net.ErrClosed) {
return err return err
} }
if l.Channel != nil { if l.channel != nil {
err := l.Channel.Close() err := l.channel.Close()
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
return err return err
} }
} }
if l.Conn != nil { if l.conn != nil {
err := l.Conn.Close() err := l.conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) { if err != nil && !errors.Is(err, net.ErrClosed) {
return err return err
} }
} }
clientSlug := l.SlugManager.Get() clientSlug := l.slugManager.Get()
if clientSlug != "" { if clientSlug != "" {
l.unregisterClient(clientSlug) l.unregisterClient(clientSlug)
} }
if l.Forwarder.GetTunnelType() == types.TCP { if l.forwarder.GetTunnelType() == types.TCP {
err := portUtil.Default.SetPortStatus(l.Forwarder.GetForwardedPort(), false) err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,7 +1,6 @@
package session package session
import ( import (
"bytes"
"fmt" "fmt"
"log" "log"
"sync" "sync"
@@ -28,36 +27,33 @@ type Session interface {
} }
type SSHSession struct { type SSHSession struct {
Lifecycle lifecycle.SessionLifecycle lifecycle lifecycle.SessionLifecycle
Interaction interaction.Controller interaction interaction.Controller
Forwarder forwarder.ForwardingController forwarder forwarder.ForwardingController
SlugManager slug.Manager slugManager slug.Manager
}
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
return s.lifecycle
}
func (s *SSHSession) GetInteraction() interaction.Controller {
return s.interaction
}
func (s *SSHSession) GetForwarder() forwarder.ForwardingController {
return s.forwarder
}
func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager
} }
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
slugManager := slug.NewManager() slugManager := slug.NewManager()
forwarderManager := &forwarder.Forwarder{ forwarderManager := forwarder.NewForwarder(slugManager)
Listener: nil, interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
TunnelType: "", lifecycleManager := lifecycle.NewLifecycle(conn, interactionManager, forwarderManager, slugManager)
ForwardedPort: 0,
SlugManager: slugManager,
}
interactionManager := &interaction.Interaction{
CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)),
InteractiveMode: false,
EditSlug: "",
SlugManager: slugManager,
Forwarder: forwarderManager,
Lifecycle: nil,
}
lifecycleManager := &lifecycle.Lifecycle{
Status: "",
Conn: conn,
Channel: nil,
Interaction: interactionManager,
Forwarder: forwarderManager,
SlugManager: slugManager,
}
interactionManager.SetLifecycle(lifecycleManager) interactionManager.SetLifecycle(lifecycleManager)
interactionManager.SetSlugModificator(updateClientSlug) interactionManager.SetSlugModificator(updateClientSlug)
@@ -65,10 +61,10 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
lifecycleManager.SetUnregisterClient(unregisterClient) lifecycleManager.SetUnregisterClient(unregisterClient)
session := &SSHSession{ session := &SSHSession{
Lifecycle: lifecycleManager, lifecycle: lifecycleManager,
Interaction: interactionManager, interaction: interactionManager,
Forwarder: forwarderManager, forwarder: forwarderManager,
SlugManager: slugManager, slugManager: slugManager,
} }
var once sync.Once var once sync.Once
@@ -79,13 +75,13 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
continue continue
} }
once.Do(func() { once.Do(func() {
session.Lifecycle.SetChannel(ch) session.lifecycle.SetChannel(ch)
session.Interaction.SetChannel(ch) session.interaction.SetChannel(ch)
tcpipReq := session.waitForTCPIPForward(forwardingReq) tcpipReq := session.waitForTCPIPForward(forwardingReq)
if tcpipReq == nil { if tcpipReq == nil {
session.Interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200"))) session.interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200")))
if err := session.Lifecycle.Close(); err != nil { if err := session.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return return
@@ -94,7 +90,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
}) })
go session.HandleGlobalRequest(reqs) go session.HandleGlobalRequest(reqs)
} }
if err := session.Lifecycle.Close(); err != nil { if err := session.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
} }
@@ -134,7 +130,7 @@ func updateClientSlug(oldSlug, newSlug string) bool {
} }
delete(Clients, oldSlug) delete(Clients, oldSlug)
client.SlugManager.Set(newSlug) client.slugManager.Set(newSlug)
Clients[newSlug] = client Clients[newSlug] = client
return true return true
} }