feat: add droping conn command
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 4m38s

This commit is contained in:
2025-12-07 15:26:37 +07:00
parent 8c8fdf251d
commit ba5f702e36
4 changed files with 193 additions and 163 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"log" "log"
"net" "net"
@ -15,11 +16,12 @@ import (
) )
type Forwarder struct { type Forwarder struct {
Listener net.Listener Listener net.Listener
TunnelType types.TunnelType TunnelType types.TunnelType
ForwardedPort uint16 ForwardedPort uint16
SlugManager slug.Manager SlugManager slug.Manager
Lifecycle Lifecycle Lifecycle Lifecycle
ActiveForwarder []chan struct{}
} }
type Lifecycle interface { type Lifecycle interface {
@ -39,6 +41,27 @@ type ForwardingController interface {
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
CreateForwardedTCPIPPayload(origin net.Addr) []byte CreateForwardedTCPIPPayload(origin net.Addr) []byte
WriteBadGatewayResponse(dst io.Writer) WriteBadGatewayResponse(dst io.Writer)
AddActiveForwarder(drop chan struct{})
DropAllForwarder() int
GetForwarderCount() int
}
func (f *Forwarder) AddActiveForwarder(drop chan struct{}) {
f.ActiveForwarder = append(f.ActiveForwarder, drop)
}
func (f *Forwarder) DropAllForwarder() int {
total := 0
for _, d := range f.ActiveForwarder {
close(d)
total += 1
}
f.ActiveForwarder = nil
return total
}
func (f *Forwarder) GetForwarderCount() int {
return len(f.ActiveForwarder)
} }
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
@ -76,6 +99,7 @@ func (f *Forwarder) AcceptTCPConnections() {
} }
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
drop := make(chan struct{})
defer func(src ssh.Channel) { defer func(src ssh.Channel) {
_, err := io.Copy(io.Discard, src) _, err := io.Copy(io.Discard, src)
if err != nil { if err != nil {
@ -96,6 +120,16 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
} }
}() }()
go func() {
select {
case <-drop:
fmt.Println("Closinggggg")
return
}
}()
f.AddActiveForwarder(drop)
_, err := io.Copy(dst, src) _, err := io.Copy(dst, src)
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {

View File

@ -26,27 +26,32 @@ type Controller interface {
SendMessage(message string) SendMessage(message string)
HandleUserInput() HandleUserInput()
HandleCommand(command string) HandleCommand(command string)
HandleSlugEditMode(connection ssh.Channel, char byte) HandleSlugEditMode(char byte)
HandleSlugSave(conn ssh.Channel) HandleSlugSave()
HandleSlugCancel(connection ssh.Channel) HandleSlugCancel()
HandleSlugUpdateError() HandleSlugUpdateError()
ShowWelcomeMessage() ShowWelcomeMessage()
DisplaySlugEditor() DisplaySlugEditor()
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
SetSlugModificator(func(oldSlug, newSlug string) bool) SetSlugModificator(func(oldSlug, newSlug string) bool)
WaitForKeyPress()
ShowForwardingMessage()
} }
type Forwarder interface { type Forwarder interface {
Close() error Close() error
GetTunnelType() types.TunnelType GetTunnelType() types.TunnelType
GetForwardedPort() uint16 GetForwardedPort() uint16
DropAllForwarder() int
GetForwarderCount() int
} }
type Interaction struct { type Interaction struct {
InputLength int InputLength int
CommandBuffer *bytes.Buffer CommandBuffer *bytes.Buffer
EditMode bool InteractiveMode bool
InteractionType types.InteractionType
EditSlug string EditSlug string
channel ssh.Channel channel ssh.Channel
SlugManager slug.Manager SlugManager slug.Manager
@ -76,8 +81,7 @@ func (i *Interaction) SendMessage(message string) {
func (i *Interaction) HandleUserInput() { func (i *Interaction) HandleUserInput() {
buf := make([]byte, 1) buf := make([]byte, 1)
i.EditMode = false i.InteractiveMode = false
for { for {
n, err := i.channel.Read(buf) n, err := i.channel.Read(buf)
if err != nil { if err != nil {
@ -89,9 +93,12 @@ func (i *Interaction) HandleUserInput() {
if n > 0 { if n > 0 {
char := buf[0] char := buf[0]
if i.InteractiveMode {
if i.EditMode { if i.InteractionType == types.Slug {
i.HandleSlugEditMode(i.channel, char) i.HandleSlugEditMode(char)
} else if i.InteractionType == types.Drop {
i.HandleDropMode(char)
}
continue continue
} }
@ -148,55 +155,34 @@ func (i *Interaction) HandleUserInput() {
if char == 13 { if char == 13 {
i.SendMessage("\033[K") i.SendMessage("\033[K")
} }
} }
} }
} }
func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte) { func (i *Interaction) HandleSlugEditMode(char byte) {
if char == 13 { if char == 13 {
i.HandleSlugSave(connection) i.HandleSlugSave()
} else if char == 27 || char == 3 { } else if char == 27 || char == 3 {
i.HandleSlugCancel(connection) i.HandleSlugCancel()
} else if char == 8 || char == 127 { } else if char == 8 || char == 127 {
if len(i.EditSlug) > 0 { if len(i.EditSlug) > 0 {
i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1] i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1]
_, err := connection.Write([]byte("\r\033[K")) i.SendMessage("\r\033[K")
if err != nil { i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain"))
log.Printf("failed to write to channel: %v", err)
return
}
_, err = connection.Write([]byte("➤ " + i.EditSlug + "." + utils.Getenv("domain")))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
} }
} else if char >= 32 && char <= 126 { } else if char >= 32 && char <= 126 {
if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' { if (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9') || char == '-' {
i.EditSlug += string(char) i.EditSlug += string(char)
_, err := connection.Write([]byte("\r\033[K")) i.SendMessage("\r\033[K")
if err != nil { i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain"))
log.Printf("failed to write to channel: %v", err)
return
}
_, err = connection.Write([]byte("➤ " + i.EditSlug + "." + utils.Getenv("domain")))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
} }
} }
} }
func (i *Interaction) HandleSlugSave(connection ssh.Channel) { func (i *Interaction) HandleSlugSave() {
isValid := isValidSlug(i.EditSlug) isValid := isValidSlug(i.EditSlug)
_, err := connection.Write([]byte("\033[H\033[2J")) i.SendMessage("\033[H\033[2J")
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
if isValid { if isValid {
oldSlug := i.SlugManager.Get() oldSlug := i.SlugManager.Get()
newSlug := i.EditSlug newSlug := i.EditSlug
@ -206,72 +192,23 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) {
return return
} }
_, err := connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")) i.SendMessage("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")
if err != nil { i.SendMessage("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n")
log.Printf("failed to write to channel: %v", err) i.SendMessage("Press any key to continue...\r\n")
return
}
_, err = connection.Write([]byte("Your new address is: " + newSlug + "." + utils.Getenv("domain") + "\r\n\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
_, err = connection.Write([]byte("Press any key to continue...\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
} else if isForbiddenSlug(i.EditSlug) { } else if isForbiddenSlug(i.EditSlug) {
_, err := connection.Write([]byte("\r\n\r\n❌ FORBIDDEN SUBDOMAIN ❌\r\n\r\n")) i.SendMessage("\r\n\r\n❌ FORBIDDEN SUBDOMAIN ❌\r\n\r\n")
if err != nil { i.SendMessage("This subdomain is not allowed.\r\n")
log.Printf("failed to write to channel: %v", err) i.SendMessage("Please try a different subdomain.\r\n\r\n")
return i.SendMessage("Press any key to continue...\r\n")
}
_, err = connection.Write([]byte("This subdomain is not allowed.\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
_, err = connection.Write([]byte("Please try a different subdomain.\r\n\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
_, err = connection.Write([]byte("Press any key to continue...\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
} else { } else {
_, err := connection.Write([]byte("\r\n\r\n❌ INVALID SUBDOMAIN ❌\r\n\r\n")) i.SendMessage("\r\n\r\n❌ INVALID SUBDOMAIN ❌\r\n\r\n")
if err != nil { i.SendMessage("Use only lowercase letters, numbers, and hyphens.\r\n")
log.Printf("failed to write to channel: %v", err) i.SendMessage("Length must be 3-20 characters and cannot start or end with a hyphen.\r\n\r\n")
return i.SendMessage("Press any key to continue...\r\n")
}
_, err = connection.Write([]byte("Use only lowercase letters, numbers, and hyphens.\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
_, err = connection.Write([]byte("Length must be 3-20 characters and cannot start or end with a hyphen.\r\n\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
_, err = connection.Write([]byte("Press any key to continue...\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
} }
waitForKeyPress(connection) i.WaitForKeyPress()
i.SendMessage("\033[H\033[2J")
_, err = connection.Write([]byte("\033[H\033[2J"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
i.ShowWelcomeMessage() i.ShowWelcomeMessage()
domain := utils.Getenv("domain") domain := utils.Getenv("domain")
@ -279,43 +216,23 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) {
if utils.Getenv("tls_enabled") == "true" { if utils.Getenv("tls_enabled") == "true" {
protocol = "https" protocol = "https"
} }
_, err = connection.Write([]byte(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))) i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
i.EditMode = false i.InteractiveMode = false
i.CommandBuffer.Reset() i.CommandBuffer.Reset()
} }
func (i *Interaction) HandleSlugCancel(connection ssh.Channel) { func (i *Interaction) HandleSlugCancel() {
i.EditMode = false i.InteractiveMode = false
_, err := connection.Write([]byte("\033[H\033[2J")) i.SendMessage("\033[H\033[2J")
if err != nil { i.SendMessage("\r\n\r\n⚠ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n")
log.Printf("failed to write to channel: %v", err) i.SendMessage("Press any key to continue...\r\n")
return
}
_, err = connection.Write([]byte("\r\n\r\n⚠ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
_, err = connection.Write([]byte("Press any key to continue...\r\n"))
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
waitForKeyPress(connection) i.WaitForKeyPress()
_, err = connection.Write([]byte("\033[H\033[2J")) i.SendMessage("\033[H\033[2J")
if err != nil {
log.Printf("failed to write to channel: %v", err)
return
}
i.ShowWelcomeMessage() i.ShowWelcomeMessage()
i.ShowForwardingMessage()
i.CommandBuffer.Reset() i.CommandBuffer.Reset()
} }
@ -349,26 +266,23 @@ func (i *Interaction) HandleCommand(command string) {
case "/clear": case "/clear":
i.SendMessage("\033[H\033[2J") i.SendMessage("\033[H\033[2J")
i.ShowWelcomeMessage() i.ShowWelcomeMessage()
domain := utils.Getenv("domain") i.ShowForwardingMessage()
if i.Forwarder.GetTunnelType() == types.HTTP {
protocol := "http"
if utils.Getenv("tls_enabled") == "true" {
protocol = "https"
}
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))
} else {
i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort()))
}
case "/slug": case "/slug":
if i.Forwarder.GetTunnelType() != types.HTTP { if i.Forwarder.GetTunnelType() != types.HTTP {
i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType())) i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType()))
} else { } else {
i.EditMode = true i.InteractiveMode = true
i.InteractionType = types.Slug
i.EditSlug = i.SlugManager.Get() i.EditSlug = i.SlugManager.Get()
i.SendMessage("\033[H\033[2J") i.SendMessage("\033[H\033[2J")
i.DisplaySlugEditor() i.DisplaySlugEditor()
i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain")) i.SendMessage("➤ " + i.EditSlug + "." + utils.Getenv("domain"))
} }
case "/drop":
i.InteractiveMode = true
i.InteractionType = types.Drop
i.SendMessage("\033[H\033[2J")
i.ShowDropMessage()
default: default:
i.SendMessage("Unknown command\r\n") i.SendMessage("Unknown command\r\n")
} }
@ -376,6 +290,80 @@ func (i *Interaction) HandleCommand(command string) {
i.CommandBuffer.Reset() i.CommandBuffer.Reset()
} }
func (i *Interaction) ShowForwardingMessage() {
domain := utils.Getenv("domain")
if i.Forwarder.GetTunnelType() == types.HTTP {
protocol := "http"
if utils.Getenv("tls_enabled") == "true" {
protocol = "https"
}
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain))
} else {
i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort()))
}
}
func (i *Interaction) HandleDropMode(char byte) {
if char == 13 || char == 121 || char == 89 {
count := i.Forwarder.DropAllForwarder()
i.SendMessage("\033[H\033[2J")
i.SendMessage(fmt.Sprintf("Dropped %d forwarders\r\n", count))
i.SendMessage("Press any key to continue...\r\n")
i.InteractiveMode = false
i.InteractionType = ""
i.WaitForKeyPress()
i.SendMessage("\033[H\033[2J")
i.ShowWelcomeMessage()
i.ShowForwardingMessage()
} else if char == 27 || char == 110 || char == 78 || char == 3 {
i.SendMessage("\033[H\033[2J")
i.SendMessage(fmt.Sprintf("Dropping canceled.\r\n"))
i.SendMessage("Press any key to continue...\r\n")
i.InteractiveMode = false
i.InteractionType = ""
i.WaitForKeyPress()
i.SendMessage("\033[H\033[2J")
i.ShowWelcomeMessage()
i.ShowForwardingMessage()
}
}
func (i *Interaction) ShowDropMessage() {
const paddingRight = 4
confirmText := fmt.Sprintf(" ║ Drop ALL %d active connections?", i.Forwarder.GetForwarderCount())
boxWidth := len(confirmText) + paddingRight + 1
if boxWidth < 50 {
boxWidth = 50
}
topBorder := " ╔" + strings.Repeat("═", boxWidth-4) + "╗\r\n"
title := centerText("DROP CONFIRMATION", boxWidth-4)
header := " ║" + title + "║\r\n"
midBorder := " ╠" + strings.Repeat("═", boxWidth-4) + "╣\r\n"
emptyLine := " ║" + strings.Repeat(" ", boxWidth-4) + "║\r\n"
confirmLine := confirmText + strings.Repeat(" ", boxWidth-len(confirmText)+1) + "║\r\n"
controlText := " ║ [Enter/Y] Confirm [N/Esc] Cancel"
controlLine := controlText + strings.Repeat(" ", boxWidth-len(controlText)+1) + "║\r\n"
bottomBorder := " ╚" + strings.Repeat("═", boxWidth-4) + "╝\r\n"
asciiArt := topBorder +
header +
midBorder +
emptyLine +
confirmLine +
emptyLine +
controlLine +
emptyLine +
bottomBorder
i.SendMessage("\r\n" + asciiArt)
i.SendMessage("\r\n\r\n")
}
func (i *Interaction) ShowWelcomeMessage() { func (i *Interaction) ShowWelcomeMessage() {
asciiArt := []string{ asciiArt := []string{
` _______ _ _____ _ `, ` _______ _ _____ _ `,
@ -393,6 +381,7 @@ func (i *Interaction) ShowWelcomeMessage() {
` - '/help' : Show this help message`, ` - '/help' : Show this help message`,
` - '/clear' : Clear the current line`, ` - '/clear' : Clear the current line`,
` - '/slug' : Set custom subdomain`, ` - '/slug' : Set custom subdomain`,
` - '/drop' : Drop all active forwarders`,
} }
for _, line := range asciiArt { for _, line := range asciiArt {
@ -443,6 +432,16 @@ func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug strin
i.updateClientSlug = modificator i.updateClientSlug = modificator
} }
func (i *Interaction) WaitForKeyPress() {
keyBuf := make([]byte, 1)
for {
_, err := i.channel.Read(keyBuf)
if err == nil {
break
}
}
}
func centerText(text string, width int) string { func centerText(text string, width int) string {
padding := (width - len(text)) / 2 padding := (width - len(text)) / 2
if padding < 0 { if padding < 0 {
@ -469,16 +468,6 @@ func isValidSlug(slug string) bool {
return true return true
} }
func waitForKeyPress(connection ssh.Channel) {
keyBuf := make([]byte, 1)
for {
_, err := connection.Read(keyBuf)
if err == nil {
break
}
}
}
func isForbiddenSlug(slug string) bool { func isForbiddenSlug(slug string) bool {
for _, s := range forbiddenSlug { for _, s := range forbiddenSlug {
if slug == s { if slug == s {

View File

@ -43,12 +43,12 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
SlugManager: slugManager, SlugManager: slugManager,
} }
interactionManager := &interaction.Interaction{ interactionManager := &interaction.Interaction{
CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)), CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)),
EditMode: false, InteractiveMode: false,
EditSlug: "", EditSlug: "",
SlugManager: slugManager, SlugManager: slugManager,
Forwarder: forwarderManager, Forwarder: forwarderManager,
Lifecycle: nil, Lifecycle: nil,
} }
lifecycleManager := &lifecycle.Lifecycle{ lifecycleManager := &lifecycle.Lifecycle{
Status: "", Status: "",

View File

@ -15,6 +15,13 @@ const (
TCP TunnelType = "TCP" TCP TunnelType = "TCP"
) )
type InteractionType string
const (
Slug InteractionType = "SLUG"
Drop InteractionType = "DROP"
)
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" + "Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" + "Content-Type: text/plain\r\n\r\n" +