refactor: restructure session initialization to avoid circular references

This commit is contained in:
2025-12-04 22:48:15 +07:00
parent 039e979142
commit 7a31047bb9
7 changed files with 229 additions and 219 deletions

View File

@ -1,9 +1,17 @@
package forwarder
import (
"bytes"
"encoding/binary"
"errors"
"io"
"log"
"net"
"strconv"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type Forwarder struct {
@ -11,14 +19,83 @@ type Forwarder struct {
TunnelType types.TunnelType
ForwardedPort uint16
SlugManager slug.Manager
Lifecycle Lifecycle
}
type Lifecycle interface {
GetConnection() ssh.Conn
}
type ForwardingController interface {
AcceptTCPConnections()
SetType(tunnelType types.TunnelType)
GetTunnelType() types.TunnelType
GetForwardedPort() uint16
SetForwardedPort(port uint16)
SetListener(listener net.Listener)
GetListener() net.Listener
Close() error
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle)
}
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
f.Lifecycle = lifecycle
}
func (f *Forwarder) AcceptTCPConnections() {
panic("implement me")
for {
conn, err := f.GetListener().Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
originHost, originPort := ParseAddr(conn.RemoteAddr().String())
payload := createForwardedTCPIPPayload(originHost, uint16(originPort), f.GetForwardedPort())
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
go func() {
for req := range reqs {
err := req.Reply(false, nil)
if err != nil {
log.Printf("Failed to reply to request: %v", err)
return
}
}
}()
go f.HandleConnection(conn, channel, conn.RemoteAddr())
}
}
func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool {
panic("implement me")
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
defer func(src ssh.Channel) {
err := src.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing connection: %v", err)
}
}(src)
log.Printf("Handling new forwarded connection from %s", remoteAddr)
go func() {
_, err := io.Copy(src, dst)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying from conn.Reader to channel: %v", err)
}
}()
_, err := io.Copy(dst, src)
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error copying from channel to conn.Writer: %v", err)
}
return
}
func (f *Forwarder) SetType(tunnelType types.TunnelType) {
@ -52,33 +129,39 @@ func (f *Forwarder) Close() error {
return nil
}
type ForwardingController interface {
AcceptTCPConnections()
UpdateClientSlug(oldSlug, newSlug string) bool
SetType(tunnelType types.TunnelType)
GetTunnelType() types.TunnelType
GetForwardedPort() uint16
SetForwardedPort(port uint16)
SetListener(listener net.Listener)
GetListener() net.Listener
Close() error
func ParseAddr(addr string) (string, uint32) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
return "0.0.0.0", uint32(0)
}
port, _ := strconv.Atoi(portStr)
return host, uint32(port)
}
func writeSSHString(buffer *bytes.Buffer, str string) {
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return
}
buffer.WriteString(str)
}
//func (f *Forwarder) UpdateClientSlug(oldSlug, newSlug string) bool {
// session.clientsMutex.Lock()
// defer session.clientsMutex.Unlock()
//
// if _, exists := session.Clients[newSlug]; exists && newSlug != oldSlug {
// return false
// }
//
// client, ok := session.Clients[oldSlug]
// if !ok {
// return false
// }
//
// delete(session.Clients, oldSlug)
// f.SlugManager.Set(newSlug)
// session.Clients[newSlug] = client
// return true
//}
func createForwardedTCPIPPayload(host string, originPort, port uint16) []byte {
var buf bytes.Buffer
writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(port))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
}
writeSSHString(&buf, host)
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
}
return buf.Bytes()
}