refactor: restructure session initialization to avoid circular references

This commit is contained in:
2025-12-04 22:48:15 +07:00
parent 039e979142
commit 7a31047bb9
7 changed files with 229 additions and 219 deletions

View File

@ -30,13 +30,13 @@ type CustomWriter struct {
buf []byte buf []byte
respHeader *ResponseHeaderFactory respHeader *ResponseHeaderFactory
reqHeader *RequestHeaderFactory reqHeader *RequestHeaderFactory
interaction interaction.InteractionController interaction interaction.Controller
respMW []ResponseMiddleware respMW []ResponseMiddleware
reqStartMW []RequestMiddleware reqStartMW []RequestMiddleware
reqEndMW []RequestMiddleware reqEndMW []RequestMiddleware
} }
func (cw *CustomWriter) SetInteraction(interaction interaction.InteractionController) { func (cw *CustomWriter) SetInteraction(interaction interaction.Controller) {
cw.interaction = interaction cw.interaction = interaction
} }
@ -350,7 +350,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
} }
} }
sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr) sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr)
return return
} }

View File

@ -29,11 +29,11 @@ func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body [
} }
type RequestLogger struct { type RequestLogger struct {
interaction interaction.InteractionController interaction interaction.Controller
remoteAddr net.Addr remoteAddr net.Addr
} }
func NewRequestLogger(interaction interaction.InteractionController, remoteAddr net.Addr) *RequestLogger { func NewRequestLogger(interaction interaction.Controller, remoteAddr net.Addr) *RequestLogger {
return &RequestLogger{ return &RequestLogger{
interaction: interaction, interaction: interaction,
remoteAddr: remoteAddr, remoteAddr: remoteAddr,

View File

@ -1,9 +1,17 @@
package forwarder package forwarder
import ( import (
"bytes"
"encoding/binary"
"errors"
"io"
"log"
"net" "net"
"strconv"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
"golang.org/x/crypto/ssh"
) )
type Forwarder struct { type Forwarder struct {
@ -11,14 +19,83 @@ type Forwarder struct {
TunnelType types.TunnelType TunnelType types.TunnelType
ForwardedPort uint16 ForwardedPort uint16
SlugManager slug.Manager SlugManager slug.Manager
Lifecycle Lifecycle
}
type Lifecycle interface {
GetConnection() ssh.Conn
}
type ForwardingController interface {
AcceptTCPConnections()
SetType(tunnelType types.TunnelType)
GetTunnelType() types.TunnelType
GetForwardedPort() uint16
SetForwardedPort(port uint16)
SetListener(listener net.Listener)
GetListener() net.Listener
Close() error
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle)
}
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
f.Lifecycle = lifecycle
} }
func (f *Forwarder) AcceptTCPConnections() { func (f *Forwarder) AcceptTCPConnections() {
panic("implement me") for {
conn, err := f.GetListener().Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
originHost, originPort := ParseAddr(conn.RemoteAddr().String())
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), f.GetForwardedPort())
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
go func() {
for req := range reqs {
err := req.Reply(false, nil)
if err != nil {
log.Printf("Failed to reply to request: %v", err)
return
}
}
}()
go f.HandleConnection(conn, channel, conn.RemoteAddr())
}
} }
func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
panic("implement me") defer func(src ssh.Channel) {
err := src.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing connection: %v", err)
}
}(src)
log.Printf("Handling new forwarded connection from %s", remoteAddr)
go func() {
_, err := io.Copy(src, dst)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying from conn.Reader to channel: %v", err)
}
}()
_, err := io.Copy(dst, src)
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error copying from channel to conn.Writer: %v", err)
}
return
} }
func (f *Forwarder) SetType(tunnelType types.TunnelType) { func (f *Forwarder) SetType(tunnelType types.TunnelType) {
@ -52,33 +129,39 @@ func (f *Forwarder) Close() error {
return nil return nil
} }
type ForwardingController interface { func ParseAddr(addr string) (string, uint32) {
AcceptTCPConnections() host, portStr, err := net.SplitHostPort(addr)
UpdateClientSlug(oldSlug, newSlug string) bool if err != nil {
SetType(tunnelType types.TunnelType) log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
GetTunnelType() types.TunnelType return "0.0.0.0", uint32(0)
GetForwardedPort() uint16 }
SetForwardedPort(port uint16) port, _ := strconv.Atoi(portStr)
SetListener(listener net.Listener) return host, uint32(port)
GetListener() net.Listener }
Close() error func writeSSHString(buffer *bytes.Buffer, str string) {
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return
}
buffer.WriteString(str)
} }
//func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool { func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
// session.clientsMutex.Lock() var buf bytes.Buffer
// defer session.clientsMutex.Unlock()
// writeSSHString(&buf, "localhost")
// if _, exists := session.Clients[newSlug]; exists && newSlug != oldSlug { err := binary.Write(&buf, binary.BigEndian, uint32(port))
// return false if err != nil {
// } log.Printf("Failed to write string to buffer: %v", err)
// return nil
// client, ok := session.Clients[oldSlug] }
// if !ok { writeSSHString(&buf, host)
// return false err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
// } if err != nil {
// log.Printf("Failed to write string to buffer: %v", err)
// delete(session.Clients, oldSlug) return nil
// f.SlugManager.Set(newSlug) }
// session.Clients[newSlug] = client
// return true return buf.Bytes()
//} }

View File

@ -3,12 +3,9 @@ package session
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io"
"log" "log"
"net" "net"
"strconv"
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
"tunnel_pls/types" "tunnel_pls/types"
@ -17,10 +14,7 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type UserConnection struct { var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
Reader io.Reader
Writer net.Conn
}
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest { for req := range GlobalRequest {
@ -157,23 +151,6 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
s.HandleTCPForward(req, addr, portToBind) s.HandleTCPForward(req, addr, portToBind)
} }
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 {
return false
}
if port < 1024 && port != 0 {
return true
}
for _, p := range blockedReservedPorts {
if p == port {
return true
}
}
return false
}
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
s.Forwarder.SetType(types.HTTP) s.Forwarder.SetType(types.HTTP)
s.Forwarder.SetForwardedPort(portToBind) s.Forwarder.SetForwardedPort(portToBind)
@ -237,7 +214,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
s.Interaction.ShowWelcomeMessage() s.Interaction.ShowWelcomeMessage()
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort()))
go s.acceptTCPConnections() go s.Forwarder.AcceptTCPConnections()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
@ -253,37 +230,6 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
} }
} }
func (s *SSHSession) acceptTCPConnections() {
for {
conn, err := s.Forwarder.GetListener().Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
originHost, originPort := ParseAddr(conn.RemoteAddr().String())
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort())
channel, reqs, err := s.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
go func() {
for req := range reqs {
err := req.Reply(false, nil)
if err != nil {
log.Printf("Failed to reply to request: %v", err)
return
}
}
}()
go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr())
}
}
func generateUniqueSlug() string { func generateUniqueSlug() string {
maxAttempts := 5 maxAttempts := 5
@ -303,30 +249,6 @@ func generateUniqueSlug() string {
return "" return ""
} }
func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
defer func(src ssh.Channel) {
err := src.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing connection: %v", err)
}
}(src)
log.Printf("Handling new forwarded connection from %s", remoteAddr)
go func() {
_, err := io.Copy(src, dst)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying from conn.Reader to channel: %v", err)
}
}()
_, err := io.Copy(dst, src)
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error copying from channel to conn.Writer: %v", err)
}
return
}
func readSSHString(reader *bytes.Reader) (string, error) { func readSSHString(reader *bytes.Reader) (string, error) {
var length uint32 var length uint32
if err := binary.Read(reader, binary.BigEndian, &length); err != nil { if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
@ -339,40 +261,17 @@ func readSSHString(reader *bytes.Reader) (string, error) {
return string(strBytes), nil return string(strBytes), nil
} }
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte { func isBlockedPort(port uint16) bool {
var buf bytes.Buffer if port == 80 || port == 443 {
return false
writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(port))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
} }
writeSSHString(&buf, host) if port < 1024 && port != 0 {
err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) return true
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
} }
for _, p := range blockedReservedPorts {
return buf.Bytes() if p == port {
} return true
}
func writeSSHString(buffer *bytes.Buffer, str string) { }
err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) return false
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return
}
buffer.WriteString(str)
}
func ParseAddr(addr string) (string, uint32) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
return "0.0.0.0", uint32(0)
}
port, _ := strconv.Atoi(portStr)
return host, uint32(port)
} }

View File

@ -14,22 +14,27 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
var forbiddenSlug = []string{
"ping",
}
type Lifecycle interface { type Lifecycle interface {
Close() error Close() error
} }
type InteractionController interface { type Controller interface {
SendMessage(message string) SendMessage(message string)
HandleUserInput() HandleUserInput()
HandleCommand(command string, commandBuffer *bytes.Buffer) HandleCommand(command string)
HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) HandleSlugEditMode(connection ssh.Channel, char byte)
HandleSlugSave(conn ssh.Channel) HandleSlugSave(conn ssh.Channel)
HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) HandleSlugCancel(connection ssh.Channel)
HandleSlugUpdateError() HandleSlugUpdateError()
ShowWelcomeMessage() ShowWelcomeMessage()
DisplaySlugEditor() DisplaySlugEditor()
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
SetSlugModificator(func(oldSlug, newSlug string) bool)
} }
type Forwarder interface { type Forwarder interface {
@ -39,13 +44,14 @@ type Forwarder interface {
} }
type Interaction struct { type Interaction struct {
CommandBuffer *bytes.Buffer CommandBuffer *bytes.Buffer
EditMode bool EditMode bool
EditSlug string EditSlug string
channel ssh.Channel channel ssh.Channel
SlugManager slug.Manager SlugManager slug.Manager
Forwarder Forwarder Forwarder Forwarder
Lifecycle Lifecycle Lifecycle Lifecycle
updateClientSlug func(oldSlug, newSlug string) bool
} }
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
@ -67,7 +73,6 @@ func (i *Interaction) SendMessage(message string) {
} }
func (i *Interaction) HandleUserInput() { func (i *Interaction) HandleUserInput() {
var commandBuffer bytes.Buffer
buf := make([]byte, 1) buf := make([]byte, 1)
i.EditMode = false i.EditMode = false
@ -84,42 +89,42 @@ func (i *Interaction) HandleUserInput() {
char := buf[0] char := buf[0]
if i.EditMode { if i.EditMode {
i.HandleSlugEditMode(i.channel, char, &commandBuffer) i.HandleSlugEditMode(i.channel, char)
continue continue
} }
i.SendMessage(string(buf[:n])) i.SendMessage(string(buf[:n]))
if char == 8 || char == 127 { if char == 8 || char == 127 {
if commandBuffer.Len() > 0 { if i.CommandBuffer.Len() > 0 {
commandBuffer.Truncate(commandBuffer.Len() - 1) i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1)
i.SendMessage("\b \b") i.SendMessage("\b \b")
} }
continue continue
} }
if char == '/' { if char == '/' {
commandBuffer.Reset() i.CommandBuffer.Reset()
commandBuffer.WriteByte(char) i.CommandBuffer.WriteByte(char)
continue continue
} }
if commandBuffer.Len() > 0 { if i.CommandBuffer.Len() > 0 {
if char == 13 { if char == 13 {
i.HandleCommand(commandBuffer.String(), &commandBuffer) i.HandleCommand(i.CommandBuffer.String())
continue continue
} }
commandBuffer.WriteByte(char) i.CommandBuffer.WriteByte(char)
} }
} }
} }
} }
func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) { func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte) {
if char == 13 { if char == 13 {
i.HandleSlugSave(connection) i.HandleSlugSave(connection)
} else if char == 27 { } else if char == 27 {
i.HandleSlugCancel(connection, commandBuffer) i.HandleSlugCancel(connection)
} else if char == 8 || char == 127 { } else if char == 8 || char == 127 {
if len(i.EditSlug) > 0 { if len(i.EditSlug) > 0 {
i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1] i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1]
@ -160,13 +165,13 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) {
return return
} }
if isValid { if isValid {
//oldSlug := i.SlugManager.Get() oldSlug := i.SlugManager.Get()
newSlug := i.EditSlug newSlug := i.EditSlug
//if !i.updateClientSlug(oldSlug, newSlug) { if !i.updateClientSlug(oldSlug, newSlug) {
// i.HandleSlugUpdateError() i.HandleSlugUpdateError()
// return return
//} }
_, err := connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n")) _, err := connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n"))
if err != nil { if err != nil {
@ -251,7 +256,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) {
i.CommandBuffer.Reset() i.CommandBuffer.Reset()
} }
func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) { func (i *Interaction) HandleSlugCancel(connection ssh.Channel) {
i.EditMode = false i.EditMode = false
_, err := connection.Write([]byte("\033[H\033[2J")) _, err := connection.Write([]byte("\033[H\033[2J"))
if err != nil { if err != nil {
@ -278,7 +283,7 @@ func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *by
} }
i.ShowWelcomeMessage() i.ShowWelcomeMessage()
commandBuffer.Reset() i.CommandBuffer.Reset()
} }
func (i *Interaction) HandleSlugUpdateError() { func (i *Interaction) HandleSlugUpdateError() {
@ -296,7 +301,7 @@ func (i *Interaction) HandleSlugUpdateError() {
} }
} }
func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) { func (i *Interaction) HandleCommand(command string) {
switch command { switch command {
case "/bye": case "/bye":
i.SendMessage("\r\nClosing connection...") i.SendMessage("\r\nClosing connection...")
@ -307,7 +312,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer)
} }
return return
case "/help": case "/help":
i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug") i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug\r\n")
case "/clear": case "/clear":
i.SendMessage("\033[H\033[2J") i.SendMessage("\033[H\033[2J")
i.ShowWelcomeMessage() i.ShowWelcomeMessage()
@ -323,7 +328,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer)
} }
case "/slug": case "/slug":
if i.Forwarder.GetTunnelType() != types.HTTP { if i.Forwarder.GetTunnelType() != types.HTTP {
i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType()))) i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType()))
} else { } else {
i.EditMode = true i.EditMode = true
i.EditSlug = i.SlugManager.Get() i.EditSlug = i.SlugManager.Get()
@ -335,7 +340,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer)
i.SendMessage("Unknown command") i.SendMessage("Unknown command")
} }
commandBuffer.Reset() i.CommandBuffer.Reset()
} }
func (i *Interaction) ShowWelcomeMessage() { func (i *Interaction) ShowWelcomeMessage() {
@ -401,6 +406,10 @@ func (i *Interaction) DisplaySlugEditor() {
i.SendMessage("\r\n\r\n") i.SendMessage("\r\n\r\n")
} }
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) bool) {
i.updateClientSlug = modificator
}
func centerText(text string, width int) string { func centerText(text string, width int) string {
padding := (width - len(text)) / 2 padding := (width - len(text)) / 2
if padding < 0 { if padding < 0 {
@ -408,6 +417,7 @@ func centerText(text string, width int) string {
} }
return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding) return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding)
} }
func isValidSlug(slug string) bool { func isValidSlug(slug string) bool {
if len(slug) < 3 || len(slug) > 20 { if len(slug) < 3 || len(slug) > 20 {
return false return false
@ -436,10 +446,6 @@ func waitForKeyPress(connection ssh.Channel) {
} }
} }
var forbiddenSlug = []string{
"ping",
}
func isForbiddenSlug(slug string) bool { func isForbiddenSlug(slug string) bool {
for _, s := range forbiddenSlug { for _, s := range forbiddenSlug {
if slug == s { if slug == s {

View File

@ -7,6 +7,7 @@ import (
"log" "log"
"net" "net"
"time" "time"
portUtil "tunnel_pls/internal/port"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
@ -20,6 +21,7 @@ type Interaction interface {
type Forwarder interface { type Forwarder interface {
Close() error Close() error
GetTunnelType() types.TunnelType GetTunnelType() types.TunnelType
GetForwardedPort() uint16
} }
type Lifecycle struct { type Lifecycle struct {
@ -27,9 +29,14 @@ type Lifecycle struct {
Conn ssh.Conn Conn ssh.Conn
Channel ssh.Channel Channel ssh.Channel
Interaction Interaction Interaction Interaction
Forwarder Forwarder Forwarder Forwarder
SlugManager slug.Manager SlugManager slug.Manager
unregisterClient func(slug string)
}
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
l.unregisterClient = unregisterClient
} }
type SessionLifecycle interface { type SessionLifecycle interface {
@ -39,6 +46,7 @@ type SessionLifecycle interface {
GetConnection() ssh.Conn GetConnection() ssh.Conn
GetChannel() ssh.Channel GetChannel() ssh.Channel
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetUnregisterClient(unregisterClient func(slug string))
} }
func (l *Lifecycle) GetChannel() ssh.Channel { func (l *Lifecycle) GetChannel() ssh.Channel {
@ -84,15 +92,9 @@ func (l *Lifecycle) WaitForRunningStatus() {
func (l *Lifecycle) Close() error { func (l *Lifecycle) Close() error {
err := l.Forwarder.Close() err := l.Forwarder.Close()
if err != nil { if err != nil && !errors.Is(err, net.ErrClosed) {
return err return err
} }
//if s.Forwarder.Listener != nil {
// err := s.Forwarder.Listener.Close()
// if err != nil && !errors.Is(err, net.ErrClosed) {
// return err
// }
//}
if l.Channel != nil { if l.Channel != nil {
err := l.Channel.Close() err := l.Channel.Close()
@ -108,17 +110,17 @@ func (l *Lifecycle) Close() error {
} }
} }
//clientSlug := l.SlugManager.Get() clientSlug := l.SlugManager.Get()
//if clientSlug != "" { if clientSlug != "" {
// unregisterClient(clientSlug) l.unregisterClient(clientSlug)
//} }
//if l.Forwarder.GetType() == "TCP" && s.Forwarder.Listener != nil { if l.Forwarder.GetTunnelType() == types.TCP {
// err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false) err := portUtil.Manager.SetPortStatus(l.Forwarder.GetForwardedPort(), false)
// if err != nil { if err != nil {
// return err return err
// } }
//} }
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package session package session
import ( import (
"bytes"
"log" "log"
"sync" "sync"
"tunnel_pls/session/forwarder" "tunnel_pls/session/forwarder"
@ -12,11 +13,12 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Session interface { var (
lifecycle.Lifecycle clientsMutex sync.RWMutex
interaction.InteractionController Clients = make(map[string]*SSHSession)
forwarder.ForwardingController )
type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request) HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request) HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16) HandleHTTPForward(req *ssh.Request, port uint16)
@ -25,7 +27,7 @@ type Session interface {
type SSHSession struct { type SSHSession struct {
Lifecycle lifecycle.SessionLifecycle Lifecycle lifecycle.SessionLifecycle
Interaction interaction.InteractionController Interaction interaction.Controller
Forwarder forwarder.ForwardingController Forwarder forwarder.ForwardingController
SlugManager slug.Manager SlugManager slug.Manager
} }
@ -39,7 +41,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
SlugManager: slugManager, SlugManager: slugManager,
} }
interactionManager := &interaction.Interaction{ interactionManager := &interaction.Interaction{
CommandBuffer: nil, CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)),
EditMode: false, EditMode: false,
EditSlug: "", EditSlug: "",
SlugManager: slugManager, SlugManager: slugManager,
@ -54,13 +56,18 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
Forwarder: forwarderManager, Forwarder: forwarderManager,
SlugManager: slugManager, SlugManager: slugManager,
} }
interactionManager.SetLifecycle(lifecycleManager)
interactionManager.SetSlugModificator(updateClientSlug)
forwarderManager.SetLifecycle(lifecycleManager)
lifecycleManager.SetUnregisterClient(unregisterClient)
session := &SSHSession{ session := &SSHSession{
Lifecycle: lifecycleManager, Lifecycle: lifecycleManager,
Interaction: interactionManager, Interaction: interactionManager,
Forwarder: forwarderManager, Forwarder: forwarderManager,
SlugManager: slugManager, SlugManager: slugManager,
} }
interactionManager.SetLifecycle(lifecycleManager)
go func() { go func() {
go session.Lifecycle.WaitForRunningStatus() go session.Lifecycle.WaitForRunningStatus()
@ -70,7 +77,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
if session.Lifecycle.GetChannel() == nil { if session.Lifecycle.GetChannel() == nil {
session.Lifecycle.SetChannel(ch) session.Lifecycle.SetChannel(ch)
session.Interaction.SetChannel(ch) session.Interaction.SetChannel(ch)
//session.Interaction.channel = ch
session.Lifecycle.SetStatus(types.SETUP) session.Lifecycle.SetStatus(types.SETUP)
go session.HandleGlobalRequest(forwardingReq) go session.HandleGlobalRequest(forwardingReq)
} }
@ -84,10 +90,24 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
}() }()
} }
var ( func updateClientSlug(oldSlug, newSlug string) bool {
clientsMutex sync.RWMutex clientsMutex.Lock()
Clients = make(map[string]*SSHSession) defer clientsMutex.Unlock()
)
if _, exists := Clients[newSlug]; exists && newSlug != oldSlug {
return false
}
client, ok := Clients[oldSlug]
if !ok {
return false
}
delete(Clients, oldSlug)
client.SlugManager.Set(newSlug)
Clients[newSlug] = client
return true
}
func registerClient(slug string, session *SSHSession) bool { func registerClient(slug string, session *SSHSession) bool {
clientsMutex.Lock() clientsMutex.Lock()