refactor: restructure session initialization to avoid circular references
This commit is contained in:
@ -30,13 +30,13 @@ type CustomWriter struct {
|
||||
buf []byte
|
||||
respHeader *ResponseHeaderFactory
|
||||
reqHeader *RequestHeaderFactory
|
||||
interaction interaction.InteractionController
|
||||
interaction interaction.Controller
|
||||
respMW []ResponseMiddleware
|
||||
reqStartMW []RequestMiddleware
|
||||
reqEndMW []RequestMiddleware
|
||||
}
|
||||
|
||||
func (cw *CustomWriter) SetInteraction(interaction interaction.InteractionController) {
|
||||
func (cw *CustomWriter) SetInteraction(interaction interaction.Controller) {
|
||||
cw.interaction = interaction
|
||||
}
|
||||
|
||||
@ -350,7 +350,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
|
||||
}
|
||||
}
|
||||
|
||||
sshSession.HandleForwardedConnection(cw, channel, cw.RemoteAddr)
|
||||
sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -29,11 +29,11 @@ func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body [
|
||||
}
|
||||
|
||||
type RequestLogger struct {
|
||||
interaction interaction.InteractionController
|
||||
interaction interaction.Controller
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func NewRequestLogger(interaction interaction.InteractionController, remoteAddr net.Addr) *RequestLogger {
|
||||
func NewRequestLogger(interaction interaction.Controller, remoteAddr net.Addr) *RequestLogger {
|
||||
return &RequestLogger{
|
||||
interaction: interaction,
|
||||
remoteAddr: remoteAddr,
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Forwarder struct {
|
||||
@ -11,14 +19,83 @@ type Forwarder struct {
|
||||
TunnelType types.TunnelType
|
||||
ForwardedPort uint16
|
||||
SlugManager slug.Manager
|
||||
Lifecycle Lifecycle
|
||||
}
|
||||
|
||||
type Lifecycle interface {
|
||||
GetConnection() ssh.Conn
|
||||
}
|
||||
|
||||
type ForwardingController interface {
|
||||
AcceptTCPConnections()
|
||||
SetType(tunnelType types.TunnelType)
|
||||
GetTunnelType() types.TunnelType
|
||||
GetForwardedPort() uint16
|
||||
SetForwardedPort(port uint16)
|
||||
SetListener(listener net.Listener)
|
||||
GetListener() net.Listener
|
||||
Close() error
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
|
||||
f.Lifecycle = lifecycle
|
||||
}
|
||||
|
||||
func (f *Forwarder) AcceptTCPConnections() {
|
||||
panic("implement me")
|
||||
for {
|
||||
conn, err := f.GetListener().Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
originHost, originPort := ParseAddr(conn.RemoteAddr().String())
|
||||
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), f.GetForwardedPort())
|
||||
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to reply to request: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go f.HandleConnection(conn, channel, conn.RemoteAddr())
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool {
|
||||
panic("implement me")
|
||||
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||
defer func(src ssh.Channel) {
|
||||
err := src.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error closing connection: %v", err)
|
||||
}
|
||||
}(src)
|
||||
log.Printf("Handling new forwarded connection from %s", remoteAddr)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(src, dst)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error copying from conn.Reader to channel: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := io.Copy(dst, src)
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error copying from channel to conn.Writer: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetType(tunnelType types.TunnelType) {
|
||||
@ -52,33 +129,39 @@ func (f *Forwarder) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type ForwardingController interface {
|
||||
AcceptTCPConnections()
|
||||
UpdateClientSlug(oldSlug, newSlug string) bool
|
||||
SetType(tunnelType types.TunnelType)
|
||||
GetTunnelType() types.TunnelType
|
||||
GetForwardedPort() uint16
|
||||
SetForwardedPort(port uint16)
|
||||
SetListener(listener net.Listener)
|
||||
GetListener() net.Listener
|
||||
Close() error
|
||||
func ParseAddr(addr string) (string, uint32) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||
return "0.0.0.0", uint32(0)
|
||||
}
|
||||
port, _ := strconv.Atoi(portStr)
|
||||
return host, uint32(port)
|
||||
}
|
||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return
|
||||
}
|
||||
buffer.WriteString(str)
|
||||
}
|
||||
|
||||
//func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool {
|
||||
// session.clientsMutex.Lock()
|
||||
// defer session.clientsMutex.Unlock()
|
||||
//
|
||||
// if _, exists := session.Clients[newSlug]; exists && newSlug != oldSlug {
|
||||
// return false
|
||||
// }
|
||||
//
|
||||
// client, ok := session.Clients[oldSlug]
|
||||
// if !ok {
|
||||
// return false
|
||||
// }
|
||||
//
|
||||
// delete(session.Clients, oldSlug)
|
||||
// f.SlugManager.Set(newSlug)
|
||||
// session.Clients[newSlug] = client
|
||||
// return true
|
||||
//}
|
||||
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
writeSSHString(&buf, "localhost")
|
||||
err := binary.Write(&buf, binary.BigEndian, uint32(port))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
}
|
||||
writeSSHString(&buf, host)
|
||||
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
@ -3,12 +3,9 @@ package session
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
"tunnel_pls/types"
|
||||
|
||||
@ -17,10 +14,7 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type UserConnection struct {
|
||||
Reader io.Reader
|
||||
Writer net.Conn
|
||||
}
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||
for req := range GlobalRequest {
|
||||
@ -157,23 +151,6 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
|
||||
s.HandleTCPForward(req, addr, portToBind)
|
||||
}
|
||||
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func isBlockedPort(port uint16) bool {
|
||||
if port == 80 || port == 443 {
|
||||
return false
|
||||
}
|
||||
if port < 1024 && port != 0 {
|
||||
return true
|
||||
}
|
||||
for _, p := range blockedReservedPorts {
|
||||
if p == port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
|
||||
s.Forwarder.SetType(types.HTTP)
|
||||
s.Forwarder.SetForwardedPort(portToBind)
|
||||
@ -237,7 +214,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
|
||||
s.Interaction.ShowWelcomeMessage()
|
||||
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s:%d \r\n", s.Forwarder.GetTunnelType(), utils.Getenv("domain"), s.Forwarder.GetForwardedPort()))
|
||||
|
||||
go s.acceptTCPConnections()
|
||||
go s.Forwarder.AcceptTCPConnections()
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
||||
@ -253,37 +230,6 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHSession) acceptTCPConnections() {
|
||||
for {
|
||||
conn, err := s.Forwarder.GetListener().Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
originHost, originPort := ParseAddr(conn.RemoteAddr().String())
|
||||
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), s.Forwarder.GetForwardedPort())
|
||||
channel, reqs, err := s.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to reply to request: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go s.HandleForwardedConnection(conn, channel, conn.RemoteAddr())
|
||||
}
|
||||
}
|
||||
|
||||
func generateUniqueSlug() string {
|
||||
maxAttempts := 5
|
||||
|
||||
@ -303,30 +249,6 @@ func generateUniqueSlug() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *SSHSession) HandleForwardedConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||
defer func(src ssh.Channel) {
|
||||
err := src.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error closing connection: %v", err)
|
||||
}
|
||||
}(src)
|
||||
log.Printf("Handling new forwarded connection from %s", remoteAddr)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(src, dst)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error copying from conn.Reader to channel: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err := io.Copy(dst, src)
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error copying from channel to conn.Writer: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func readSSHString(reader *bytes.Reader) (string, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
|
||||
@ -339,40 +261,17 @@ func readSSHString(reader *bytes.Reader) (string, error) {
|
||||
return string(strBytes), nil
|
||||
}
|
||||
|
||||
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
writeSSHString(&buf, "localhost")
|
||||
err := binary.Write(&buf, binary.BigEndian, uint32(port))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
func isBlockedPort(port uint16) bool {
|
||||
if port == 80 || port == 443 {
|
||||
return false
|
||||
}
|
||||
writeSSHString(&buf, host)
|
||||
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
if port < 1024 && port != 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return
|
||||
}
|
||||
buffer.WriteString(str)
|
||||
}
|
||||
|
||||
func ParseAddr(addr string) (string, uint32) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||
return "0.0.0.0", uint32(0)
|
||||
}
|
||||
port, _ := strconv.Atoi(portStr)
|
||||
return host, uint32(port)
|
||||
for _, p := range blockedReservedPorts {
|
||||
if p == port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@ -14,22 +14,27 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var forbiddenSlug = []string{
|
||||
"ping",
|
||||
}
|
||||
|
||||
type Lifecycle interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
type InteractionController interface {
|
||||
type Controller interface {
|
||||
SendMessage(message string)
|
||||
HandleUserInput()
|
||||
HandleCommand(command string, commandBuffer *bytes.Buffer)
|
||||
HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer)
|
||||
HandleCommand(command string)
|
||||
HandleSlugEditMode(connection ssh.Channel, char byte)
|
||||
HandleSlugSave(conn ssh.Channel)
|
||||
HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer)
|
||||
HandleSlugCancel(connection ssh.Channel)
|
||||
HandleSlugUpdateError()
|
||||
ShowWelcomeMessage()
|
||||
DisplaySlugEditor()
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
SetSlugModificator(func(oldSlug, newSlug string) bool)
|
||||
}
|
||||
|
||||
type Forwarder interface {
|
||||
@ -46,6 +51,7 @@ type Interaction struct {
|
||||
SlugManager slug.Manager
|
||||
Forwarder Forwarder
|
||||
Lifecycle Lifecycle
|
||||
updateClientSlug func(oldSlug, newSlug string) bool
|
||||
}
|
||||
|
||||
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
|
||||
@ -67,7 +73,6 @@ func (i *Interaction) SendMessage(message string) {
|
||||
}
|
||||
|
||||
func (i *Interaction) HandleUserInput() {
|
||||
var commandBuffer bytes.Buffer
|
||||
buf := make([]byte, 1)
|
||||
i.EditMode = false
|
||||
|
||||
@ -84,42 +89,42 @@ func (i *Interaction) HandleUserInput() {
|
||||
char := buf[0]
|
||||
|
||||
if i.EditMode {
|
||||
i.HandleSlugEditMode(i.channel, char, &commandBuffer)
|
||||
i.HandleSlugEditMode(i.channel, char)
|
||||
continue
|
||||
}
|
||||
|
||||
i.SendMessage(string(buf[:n]))
|
||||
|
||||
if char == 8 || char == 127 {
|
||||
if commandBuffer.Len() > 0 {
|
||||
commandBuffer.Truncate(commandBuffer.Len() - 1)
|
||||
if i.CommandBuffer.Len() > 0 {
|
||||
i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1)
|
||||
i.SendMessage("\b \b")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if char == '/' {
|
||||
commandBuffer.Reset()
|
||||
commandBuffer.WriteByte(char)
|
||||
i.CommandBuffer.Reset()
|
||||
i.CommandBuffer.WriteByte(char)
|
||||
continue
|
||||
}
|
||||
|
||||
if commandBuffer.Len() > 0 {
|
||||
if i.CommandBuffer.Len() > 0 {
|
||||
if char == 13 {
|
||||
i.HandleCommand(commandBuffer.String(), &commandBuffer)
|
||||
i.HandleCommand(i.CommandBuffer.String())
|
||||
continue
|
||||
}
|
||||
commandBuffer.WriteByte(char)
|
||||
i.CommandBuffer.WriteByte(char)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte, commandBuffer *bytes.Buffer) {
|
||||
func (i *Interaction) HandleSlugEditMode(connection ssh.Channel, char byte) {
|
||||
if char == 13 {
|
||||
i.HandleSlugSave(connection)
|
||||
} else if char == 27 {
|
||||
i.HandleSlugCancel(connection, commandBuffer)
|
||||
i.HandleSlugCancel(connection)
|
||||
} else if char == 8 || char == 127 {
|
||||
if len(i.EditSlug) > 0 {
|
||||
i.EditSlug = (i.EditSlug)[:len(i.EditSlug)-1]
|
||||
@ -160,13 +165,13 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) {
|
||||
return
|
||||
}
|
||||
if isValid {
|
||||
//oldSlug := i.SlugManager.Get()
|
||||
oldSlug := i.SlugManager.Get()
|
||||
newSlug := i.EditSlug
|
||||
|
||||
//if !i.updateClientSlug(oldSlug, newSlug) {
|
||||
// i.HandleSlugUpdateError()
|
||||
// return
|
||||
//}
|
||||
if !i.updateClientSlug(oldSlug, newSlug) {
|
||||
i.HandleSlugUpdateError()
|
||||
return
|
||||
}
|
||||
|
||||
_, err := connection.Write([]byte("\r\n\r\n✅ SUBDOMAIN UPDATED ✅\r\n\r\n"))
|
||||
if err != nil {
|
||||
@ -251,7 +256,7 @@ func (i *Interaction) HandleSlugSave(connection ssh.Channel) {
|
||||
i.CommandBuffer.Reset()
|
||||
}
|
||||
|
||||
func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *bytes.Buffer) {
|
||||
func (i *Interaction) HandleSlugCancel(connection ssh.Channel) {
|
||||
i.EditMode = false
|
||||
_, err := connection.Write([]byte("\033[H\033[2J"))
|
||||
if err != nil {
|
||||
@ -278,7 +283,7 @@ func (i *Interaction) HandleSlugCancel(connection ssh.Channel, commandBuffer *by
|
||||
}
|
||||
i.ShowWelcomeMessage()
|
||||
|
||||
commandBuffer.Reset()
|
||||
i.CommandBuffer.Reset()
|
||||
}
|
||||
|
||||
func (i *Interaction) HandleSlugUpdateError() {
|
||||
@ -296,7 +301,7 @@ func (i *Interaction) HandleSlugUpdateError() {
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer) {
|
||||
func (i *Interaction) HandleCommand(command string) {
|
||||
switch command {
|
||||
case "/bye":
|
||||
i.SendMessage("\r\nClosing connection...")
|
||||
@ -307,7 +312,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer)
|
||||
}
|
||||
return
|
||||
case "/help":
|
||||
i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug")
|
||||
i.SendMessage("\r\nAvailable commands: /bye, /help, /clear, /slug\r\n")
|
||||
case "/clear":
|
||||
i.SendMessage("\033[H\033[2J")
|
||||
i.ShowWelcomeMessage()
|
||||
@ -323,7 +328,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer)
|
||||
}
|
||||
case "/slug":
|
||||
if i.Forwarder.GetTunnelType() != types.HTTP {
|
||||
i.SendMessage((fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType())))
|
||||
i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains", i.Forwarder.GetTunnelType()))
|
||||
} else {
|
||||
i.EditMode = true
|
||||
i.EditSlug = i.SlugManager.Get()
|
||||
@ -335,7 +340,7 @@ func (i *Interaction) HandleCommand(command string, commandBuffer *bytes.Buffer)
|
||||
i.SendMessage("Unknown command")
|
||||
}
|
||||
|
||||
commandBuffer.Reset()
|
||||
i.CommandBuffer.Reset()
|
||||
}
|
||||
|
||||
func (i *Interaction) ShowWelcomeMessage() {
|
||||
@ -401,6 +406,10 @@ func (i *Interaction) DisplaySlugEditor() {
|
||||
i.SendMessage("\r\n\r\n")
|
||||
}
|
||||
|
||||
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) bool) {
|
||||
i.updateClientSlug = modificator
|
||||
}
|
||||
|
||||
func centerText(text string, width int) string {
|
||||
padding := (width - len(text)) / 2
|
||||
if padding < 0 {
|
||||
@ -408,6 +417,7 @@ func centerText(text string, width int) string {
|
||||
}
|
||||
return strings.Repeat(" ", padding) + text + strings.Repeat(" ", width-len(text)-padding)
|
||||
}
|
||||
|
||||
func isValidSlug(slug string) bool {
|
||||
if len(slug) < 3 || len(slug) > 20 {
|
||||
return false
|
||||
@ -436,10 +446,6 @@ func waitForKeyPress(connection ssh.Channel) {
|
||||
}
|
||||
}
|
||||
|
||||
var forbiddenSlug = []string{
|
||||
"ping",
|
||||
}
|
||||
|
||||
func isForbiddenSlug(slug string) bool {
|
||||
for _, s := range forbiddenSlug {
|
||||
if slug == s {
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
@ -20,6 +21,7 @@ type Interaction interface {
|
||||
type Forwarder interface {
|
||||
Close() error
|
||||
GetTunnelType() types.TunnelType
|
||||
GetForwardedPort() uint16
|
||||
}
|
||||
|
||||
type Lifecycle struct {
|
||||
@ -30,6 +32,11 @@ type Lifecycle struct {
|
||||
Interaction Interaction
|
||||
Forwarder Forwarder
|
||||
SlugManager slug.Manager
|
||||
unregisterClient func(slug string)
|
||||
}
|
||||
|
||||
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
|
||||
l.unregisterClient = unregisterClient
|
||||
}
|
||||
|
||||
type SessionLifecycle interface {
|
||||
@ -39,6 +46,7 @@ type SessionLifecycle interface {
|
||||
GetConnection() ssh.Conn
|
||||
GetChannel() ssh.Channel
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetUnregisterClient(unregisterClient func(slug string))
|
||||
}
|
||||
|
||||
func (l *Lifecycle) GetChannel() ssh.Channel {
|
||||
@ -84,15 +92,9 @@ func (l *Lifecycle) WaitForRunningStatus() {
|
||||
|
||||
func (l *Lifecycle) Close() error {
|
||||
err := l.Forwarder.Close()
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return err
|
||||
}
|
||||
//if s.Forwarder.Listener != nil {
|
||||
// err := s.Forwarder.Listener.Close()
|
||||
// if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
// return err
|
||||
// }
|
||||
//}
|
||||
|
||||
if l.Channel != nil {
|
||||
err := l.Channel.Close()
|
||||
@ -108,17 +110,17 @@ func (l *Lifecycle) Close() error {
|
||||
}
|
||||
}
|
||||
|
||||
//clientSlug := l.SlugManager.Get()
|
||||
//if clientSlug != "" {
|
||||
// unregisterClient(clientSlug)
|
||||
//}
|
||||
clientSlug := l.SlugManager.Get()
|
||||
if clientSlug != "" {
|
||||
l.unregisterClient(clientSlug)
|
||||
}
|
||||
|
||||
//if l.Forwarder.GetType() == "TCP" && s.Forwarder.Listener != nil {
|
||||
// err := portUtil.Manager.SetPortStatus(s.Forwarder.ForwardedPort, false)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//}
|
||||
if l.Forwarder.GetTunnelType() == types.TCP {
|
||||
err := portUtil.Manager.SetPortStatus(l.Forwarder.GetForwardedPort(), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"sync"
|
||||
"tunnel_pls/session/forwarder"
|
||||
@ -12,11 +13,12 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Session interface {
|
||||
lifecycle.Lifecycle
|
||||
interaction.InteractionController
|
||||
forwarder.ForwardingController
|
||||
var (
|
||||
clientsMutex sync.RWMutex
|
||||
Clients = make(map[string]*SSHSession)
|
||||
)
|
||||
|
||||
type Session interface {
|
||||
HandleGlobalRequest(ch <-chan *ssh.Request)
|
||||
HandleTCPIPForward(req *ssh.Request)
|
||||
HandleHTTPForward(req *ssh.Request, port uint16)
|
||||
@ -25,7 +27,7 @@ type Session interface {
|
||||
|
||||
type SSHSession struct {
|
||||
Lifecycle lifecycle.SessionLifecycle
|
||||
Interaction interaction.InteractionController
|
||||
Interaction interaction.Controller
|
||||
Forwarder forwarder.ForwardingController
|
||||
SlugManager slug.Manager
|
||||
}
|
||||
@ -39,7 +41,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
||||
SlugManager: slugManager,
|
||||
}
|
||||
interactionManager := &interaction.Interaction{
|
||||
CommandBuffer: nil,
|
||||
CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)),
|
||||
EditMode: false,
|
||||
EditSlug: "",
|
||||
SlugManager: slugManager,
|
||||
@ -54,13 +56,18 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
||||
Forwarder: forwarderManager,
|
||||
SlugManager: slugManager,
|
||||
}
|
||||
|
||||
interactionManager.SetLifecycle(lifecycleManager)
|
||||
interactionManager.SetSlugModificator(updateClientSlug)
|
||||
forwarderManager.SetLifecycle(lifecycleManager)
|
||||
lifecycleManager.SetUnregisterClient(unregisterClient)
|
||||
|
||||
session := &SSHSession{
|
||||
Lifecycle: lifecycleManager,
|
||||
Interaction: interactionManager,
|
||||
Forwarder: forwarderManager,
|
||||
SlugManager: slugManager,
|
||||
}
|
||||
interactionManager.SetLifecycle(lifecycleManager)
|
||||
|
||||
go func() {
|
||||
go session.Lifecycle.WaitForRunningStatus()
|
||||
@ -70,7 +77,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
||||
if session.Lifecycle.GetChannel() == nil {
|
||||
session.Lifecycle.SetChannel(ch)
|
||||
session.Interaction.SetChannel(ch)
|
||||
//session.Interaction.channel = ch
|
||||
session.Lifecycle.SetStatus(types.SETUP)
|
||||
go session.HandleGlobalRequest(forwardingReq)
|
||||
}
|
||||
@ -84,10 +90,24 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
||||
}()
|
||||
}
|
||||
|
||||
var (
|
||||
clientsMutex sync.RWMutex
|
||||
Clients = make(map[string]*SSHSession)
|
||||
)
|
||||
func updateClientSlug(oldSlug, newSlug string) bool {
|
||||
clientsMutex.Lock()
|
||||
defer clientsMutex.Unlock()
|
||||
|
||||
if _, exists := Clients[newSlug]; exists && newSlug != oldSlug {
|
||||
return false
|
||||
}
|
||||
|
||||
client, ok := Clients[oldSlug]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(Clients, oldSlug)
|
||||
client.SlugManager.Set(newSlug)
|
||||
Clients[newSlug] = client
|
||||
return true
|
||||
}
|
||||
|
||||
func registerClient(slug string, session *SSHSession) bool {
|
||||
clientsMutex.Lock()
|
||||
|
||||
Reference in New Issue
Block a user