From ba5f702e36591167bb9275dfe151c27776260bbf Mon Sep 17 00:00:00 2001 From: bagas Date: Sun, 7 Dec 2025 15:26:37 +0700 Subject: [PATCH] feat: add droping conn command --- session/forwarder/forwarder.go | 44 ++++- session/interaction/interaction.go | 293 ++++++++++++++--------------- session/session.go | 12 +- types/types.go | 7 + 4 files changed, 193 insertions(+), 163 deletions(-) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 3d846e6..462df4b 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" "log" "net" @@ -15,11 +16,12 @@ import ( ) type Forwarder struct { - Listener net.Listener - TunnelType types.TunnelType - ForwardedPort uint16 - SlugManager slug.Manager - Lifecycle Lifecycle + Listener net.Listener + TunnelType types.TunnelType + ForwardedPort uint16 + SlugManager slug.Manager + Lifecycle Lifecycle + ActiveForwarder []chan struct{} } type Lifecycle interface { @@ -39,6 +41,27 @@ type ForwardingController interface { SetLifecycle(lifecycle Lifecycle) CreateForwardedTCPIPPayload(origin net.Addr) []byte WriteBadGatewayResponse(dst io.Writer) + AddActiveForwarder(drop chan struct{}) + DropAllForwarder() int + GetForwarderCount() int +} + +func (f *Forwarder) AddActiveForwarder(drop chan struct{}) { + f.ActiveForwarder = append(f.ActiveForwarder, drop) +} + +func (f *Forwarder) DropAllForwarder() int { + total := 0 + for _, d := range f.ActiveForwarder { + close(d) + total += 1 + } + f.ActiveForwarder = nil + return total +} + +func (f *Forwarder) GetForwarderCount() int { + return len(f.ActiveForwarder) } func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { @@ -76,6 +99,7 @@ func (f *Forwarder) AcceptTCPConnections() { } func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { + drop := make(chan struct{}) defer func(src ssh.Channel) { _, err := io.Copy(io.Discard, src) if err != nil { @@ -96,6 +120,16 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA } }() + go func() { + select { + case <-drop: + fmt.Println("Closinggggg") + return + } + }() + + f.AddActiveForwarder(drop) + _, err := io.Copy(dst, src) if err != nil && !errors.Is(err, io.EOF) { diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 049608e..d427978 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -26,27 +26,32 @@ type Controller interface { SendMessage(message string) HandleUserInput() HandleCommand(command string) - HandleSlugEditMode(connection ssh.Channel, char byte) - HandleSlugSave(conn ssh.Channel) - HandleSlugCancel(connection ssh.Channel) + HandleSlugEditMode(char byte) + HandleSlugSave() + HandleSlugCancel() HandleSlugUpdateError() ShowWelcomeMessage() DisplaySlugEditor() SetChannel(channel ssh.Channel) SetLifecycle(lifecycle Lifecycle) SetSlugModificator(func(oldSlug, newSlug string) bool) + WaitForKeyPress() + ShowForwardingMessage() } type Forwarder interface { Close() error GetTunnelType() types.TunnelType GetForwardedPort() uint16 + DropAllForwarder() int + GetForwarderCount() int } type Interaction struct { InputLength int CommandBuffer *bytes.Buffer - EditMode bool + InteractiveMode bool + InteractionType types.InteractionType EditSlug string channel ssh.Channel SlugManager slug.Manager @@ -76,8 +81,7 @@ func (i *Interaction) SendMessage(message string) { func (i *Interaction) HandleUserInput() { buf := make([]byte, 1) - i.EditMode = false - + i.InteractiveMode = false for { n, err := i.channel.Read(buf) if err != nil { @@ -89,9 +93,12 @@ func (i *Interaction) HandleUserInput() { if n > 0 { char := buf[0] - - if i.EditMode { - i.HandleSlugEditMode(i.channel, char) + if i.InteractiveMode { + if i.InteractionType == types.Slug { + i.HandleSlugEditMode(char) + } else if i.InteractionType == types.Drop { + i.HandleDropMode(char) + } continue } @@ -148,55 +155,34 @@ func (i *Interaction) HandleUserInput() { if char == 13 { i.SendMessage("\033[K") } - } } } -func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte) { +func (i *Interaction) HandleSlugEditMode(char byte) { if char == 13 { - i.HandleSlugSave(connection) + i.HandleSlugSave() } else if char == 27 || char == 3 { - i.HandleSlugCancel(connection) + i.HandleSlugCancel() } else if char == 8 || char == 127 { if len(i.EditSlug) > 0 { i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1] - _, err := connection.Write([]byte("\r\033[K")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("➤ " + i.EditSlug + "." + utils.Getenv("domain"))) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } + i.SendMessage("\r\033[K") + i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain")) } } else if char >= 32 && char <= 126 { if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' { i.EditSlug += string(char) - _, err := connection.Write([]byte("\r\033[K")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("➤ " + i.EditSlug + "." + utils.Getenv("domain"))) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } + i.SendMessage("\r\033[K") + i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain")) } } } -func (i *Interaction) HandleSlugSave(connection ssh.Channel) { +func (i *Interaction) HandleSlugSave() { isValid := isValidSlug(i.EditSlug) - _, err := connection.Write([]byte("\033[H\033[2J")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } + i.SendMessage("\033[H\033[2J") if isValid { oldSlug := i.SlugManager.Get() newSlug := i.EditSlug @@ -206,72 +192,23 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { return } - _, err := connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("Press any key to continue...\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } + i.SendMessage("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n") + i.SendMessage("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n") + i.SendMessage("Press any key to continue...\r\n") } else if isForbiddenSlug(i.EditSlug) { - _, err := connection.Write([]byte("\r\n\r\n❌ FORBIDDEN SUBDOMAIN ❌\r\n\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("This subdomain is not allowed.\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("Please try a different subdomain.\r\n\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("Press any key to continue...\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } + i.SendMessage("\r\n\r\n❌ FORBIDDEN SUBDOMAIN ❌\r\n\r\n") + i.SendMessage("This subdomain is not allowed.\r\n") + i.SendMessage("Please try a different subdomain.\r\n\r\n") + i.SendMessage("Press any key to continue...\r\n") } else { - _, err := connection.Write([]byte("\r\n\r\n❌ INVALID SUBDOMAIN ❌\r\n\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("Use only lowercase letters, numbers, and hyphens.\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("Length must be 3-20 characters and cannot start or end with a hyphen.\r\n\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("Press any key to continue...\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } + i.SendMessage("\r\n\r\n❌ INVALID SUBDOMAIN ❌\r\n\r\n") + i.SendMessage("Use only lowercase letters, numbers, and hyphens.\r\n") + i.SendMessage("Length must be 3-20 characters and cannot start or end with a hyphen.\r\n\r\n") + i.SendMessage("Press any key to continue...\r\n") } - waitForKeyPress(connection) - - _, err = connection.Write([]byte("\033[H\033[2J")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } + i.WaitForKeyPress() + i.SendMessage("\033[H\033[2J") i.ShowWelcomeMessage() domain := utils.Getenv("domain") @@ -279,43 +216,23 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) { if utils.Getenv("tls_enabled") == "true" { protocol = "https" } - _, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } + i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) - i.EditMode = false + i.InteractiveMode = false i.CommandBuffer.Reset() } -func (i *Interaction) HandleSlugCancel(connection ssh.Channel) { - i.EditMode = false - _, err := connection.Write([]byte("\033[H\033[2J")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("\r\n\r\n⚠️ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } - _, err = connection.Write([]byte("Press any key to continue...\r\n")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } +func (i *Interaction) HandleSlugCancel() { + i.InteractiveMode = false + i.SendMessage("\033[H\033[2J") + i.SendMessage("\r\n\r\n⚠️ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n") + i.SendMessage("Press any key to continue...\r\n") - waitForKeyPress(connection) + i.WaitForKeyPress() - _, err = connection.Write([]byte("\033[H\033[2J")) - if err != nil { - log.Printf("failed to write to channel: %v", err) - return - } + i.SendMessage("\033[H\033[2J") i.ShowWelcomeMessage() - + i.ShowForwardingMessage() i.CommandBuffer.Reset() } @@ -349,26 +266,23 @@ func (i *Interaction) HandleCommand(command string) { case "/clear": i.SendMessage("\033[H\033[2J") i.ShowWelcomeMessage() - domain := utils.Getenv("domain") - if i.Forwarder.GetTunnelType() == types.HTTP { - protocol := "http" - if utils.Getenv("tls_enabled") == "true" { - protocol = "https" - } - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) - } else { - i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort())) - } + i.ShowForwardingMessage() case "/slug": if i.Forwarder.GetTunnelType() != types.HTTP { i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType())) } else { - i.EditMode = true + i.InteractiveMode = true + i.InteractionType = types.Slug i.EditSlug = i.SlugManager.Get() i.SendMessage("\033[H\033[2J") i.DisplaySlugEditor() i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain")) } + case "/drop": + i.InteractiveMode = true + i.InteractionType = types.Drop + i.SendMessage("\033[H\033[2J") + i.ShowDropMessage() default: i.SendMessage("Unknown command\r\n") } @@ -376,6 +290,80 @@ func (i *Interaction) HandleCommand(command string) { i.CommandBuffer.Reset() } +func (i *Interaction) ShowForwardingMessage() { + domain := utils.Getenv("domain") + if i.Forwarder.GetTunnelType() == types.HTTP { + protocol := "http" + if utils.Getenv("tls_enabled") == "true" { + protocol = "https" + } + i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) + } else { + i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort())) + } +} + +func (i *Interaction) HandleDropMode(char byte) { + if char == 13 || char == 121 || char == 89 { + count := i.Forwarder.DropAllForwarder() + i.SendMessage("\033[H\033[2J") + i.SendMessage(fmt.Sprintf("Dropped %d forwarders\r\n", count)) + i.SendMessage("Press any key to continue...\r\n") + i.InteractiveMode = false + i.InteractionType = "" + i.WaitForKeyPress() + i.SendMessage("\033[H\033[2J") + i.ShowWelcomeMessage() + i.ShowForwardingMessage() + } else if char == 27 || char == 110 || char == 78 || char == 3 { + i.SendMessage("\033[H\033[2J") + i.SendMessage(fmt.Sprintf("Dropping canceled.\r\n")) + i.SendMessage("Press any key to continue...\r\n") + i.InteractiveMode = false + i.InteractionType = "" + i.WaitForKeyPress() + i.SendMessage("\033[H\033[2J") + i.ShowWelcomeMessage() + i.ShowForwardingMessage() + } +} + +func (i *Interaction) ShowDropMessage() { + const paddingRight = 4 + + confirmText := fmt.Sprintf(" ║ Drop ALL %d active connections?", i.Forwarder.GetForwarderCount()) + boxWidth := len(confirmText) + paddingRight + 1 + if boxWidth < 50 { + boxWidth = 50 + } + + topBorder := " ╔" + strings.Repeat("═", boxWidth-4) + "╗\r\n" + title := centerText("DROP CONFIRMATION", boxWidth-4) + header := " ║" + title + "║\r\n" + midBorder := " ╠" + strings.Repeat("═", boxWidth-4) + "╣\r\n" + emptyLine := " ║" + strings.Repeat(" ", boxWidth-4) + "║\r\n" + + confirmLine := confirmText + strings.Repeat(" ", boxWidth-len(confirmText)+1) + "║\r\n" + + controlText := " ║ [Enter/Y] Confirm [N/Esc] Cancel" + controlLine := controlText + strings.Repeat(" ", boxWidth-len(controlText)+1) + "║\r\n" + + bottomBorder := " ╚" + strings.Repeat("═", boxWidth-4) + "╝\r\n" + + asciiArt := topBorder + + header + + midBorder + + emptyLine + + confirmLine + + emptyLine + + controlLine + + emptyLine + + bottomBorder + + i.SendMessage("\r\n" + asciiArt) + i.SendMessage("\r\n\r\n") +} + func (i *Interaction) ShowWelcomeMessage() { asciiArt := []string{ ` _______ _ _____ _ `, @@ -393,6 +381,7 @@ func (i *Interaction) ShowWelcomeMessage() { ` - '/help' : Show this help message`, ` - '/clear' : Clear the current line`, ` - '/slug' : Set custom subdomain`, + ` - '/drop' : Drop all active forwarders`, } for _, line := range asciiArt { @@ -443,6 +432,16 @@ func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug strin i.updateClientSlug = modificator } +func (i *Interaction) WaitForKeyPress() { + keyBuf := make([]byte, 1) + for { + _, err := i.channel.Read(keyBuf) + if err == nil { + break + } + } +} + func centerText(text string, width int) string { padding := (width - len(text)) / 2 if padding < 0 { @@ -469,16 +468,6 @@ func isValidSlug(slug string) bool { return true } -func waitForKeyPress(connection ssh.Channel) { - keyBuf := make([]byte, 1) - for { - _, err := connection.Read(keyBuf) - if err == nil { - break - } - } -} - func isForbiddenSlug(slug string) bool { for _, s := range forbiddenSlug { if slug == s { diff --git a/session/session.go b/session/session.go index aabbe63..6c70ec5 100644 --- a/session/session.go +++ b/session/session.go @@ -43,12 +43,12 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan SlugManager: slugManager, } interactionManager := &interaction.Interaction{ - CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)), - EditMode: false, - EditSlug: "", - SlugManager: slugManager, - Forwarder: forwarderManager, - Lifecycle: nil, + CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)), + InteractiveMode: false, + EditSlug: "", + SlugManager: slugManager, + Forwarder: forwarderManager, + Lifecycle: nil, } lifecycleManager := &lifecycle.Lifecycle{ Status: "", diff --git a/types/types.go b/types/types.go index f909da5..1ad818c 100644 --- a/types/types.go +++ b/types/types.go @@ -15,6 +15,13 @@ const ( TCP TunnelType = "TCP" ) +type InteractionType string + +const ( + Slug InteractionType = "SLUG" + Drop InteractionType = "DROP" +) + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + "Content-Length: 11\r\n" + "Content-Type: text/plain\r\n\r\n" +