refactor: explicit initialization and dependency injection
- Replace init() with config.Load() function when loading env variables - Inject portRegistry into session, server, and lifecycle structs - Inject sessionRegistry directly into interaction and lifecycle - Remove SetSessionRegistry function and global port variables
This commit is contained in:
@@ -1,19 +1,17 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func Load() error {
|
||||||
if _, err := os.Stat(".env"); err == nil {
|
if _, err := os.Stat(".env"); err == nil {
|
||||||
if err := godotenv.Load(".env"); err != nil {
|
return godotenv.Load(".env")
|
||||||
log.Printf("Warning: Failed to load .env file: %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Getenv(key, defaultValue string) string {
|
func Getenv(key, defaultValue string) string {
|
||||||
|
|||||||
+9
-32
@@ -3,53 +3,30 @@ package port
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"tunnel_pls/internal/config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Registry interface {
|
||||||
AddPortRange(startPort, endPort uint16) error
|
AddPortRange(startPort, endPort uint16) error
|
||||||
GetUnassignedPort() (uint16, bool)
|
GetUnassignedPort() (uint16, bool)
|
||||||
SetPortStatus(port uint16, assigned bool) error
|
SetPortStatus(port uint16, assigned bool) error
|
||||||
ClaimPort(port uint16) (claimed bool)
|
ClaimPort(port uint16) (claimed bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type manager struct {
|
type registry struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ports map[uint16]bool
|
ports map[uint16]bool
|
||||||
sortedPorts []uint16
|
sortedPorts []uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
var Default Manager = &manager{
|
func New() Registry {
|
||||||
|
return ®istry{
|
||||||
ports: make(map[uint16]bool),
|
ports: make(map[uint16]bool),
|
||||||
sortedPorts: []uint16{},
|
sortedPorts: []uint16{},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func (pm *registry) AddPortRange(startPort, endPort uint16) error {
|
||||||
rawRange := config.Getenv("ALLOWED_PORTS", "")
|
|
||||||
if rawRange == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
splitRange := strings.Split(rawRange, "-")
|
|
||||||
if len(splitRange) != 2 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
start, err := strconv.ParseUint(splitRange[0], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
end, err := strconv.ParseUint(splitRange[1], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = Default.AddPortRange(uint16(start), uint16(end))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *manager) AddPortRange(startPort, endPort uint16) error {
|
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
@@ -68,7 +45,7 @@ func (pm *manager) AddPortRange(startPort, endPort uint16) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *manager) GetUnassignedPort() (uint16, bool) {
|
func (pm *registry) GetUnassignedPort() (uint16, bool) {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
@@ -80,7 +57,7 @@ func (pm *manager) GetUnassignedPort() (uint16, bool) {
|
|||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
|
func (pm *registry) SetPortStatus(port uint16, assigned bool) error {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
@@ -88,7 +65,7 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *manager) ClaimPort(port uint16) (claimed bool) {
|
func (pm *registry) ClaimPort(port uint16) (claimed bool) {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -8,12 +8,14 @@ import (
|
|||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/internal/grpc/client"
|
"tunnel_pls/internal/grpc/client"
|
||||||
"tunnel_pls/internal/key"
|
"tunnel_pls/internal/key"
|
||||||
|
"tunnel_pls/internal/port"
|
||||||
"tunnel_pls/server"
|
"tunnel_pls/server"
|
||||||
"tunnel_pls/session"
|
"tunnel_pls/session"
|
||||||
"tunnel_pls/version"
|
"tunnel_pls/version"
|
||||||
@@ -32,6 +34,12 @@ func main() {
|
|||||||
|
|
||||||
log.Printf("Starting %s", version.GetVersion())
|
log.Printf("Starting %s", version.GetVersion())
|
||||||
|
|
||||||
|
err := config.Load()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load configuration: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
mode := strings.ToLower(config.Getenv("MODE", "standalone"))
|
mode := strings.ToLower(config.Getenv("MODE", "standalone"))
|
||||||
isNodeMode := mode == "node"
|
isNodeMode := mode == "node"
|
||||||
|
|
||||||
@@ -41,7 +49,7 @@ func main() {
|
|||||||
go func() {
|
go func() {
|
||||||
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
|
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
|
||||||
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
||||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||||
log.Printf("pprof server error: %v", err)
|
log.Printf("pprof server error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -53,7 +61,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sshKeyPath := "certs/ssh/id_rsa"
|
sshKeyPath := "certs/ssh/id_rsa"
|
||||||
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
|
if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
|
||||||
log.Fatalf("Failed to generate SSH key: %s", err)
|
log.Fatalf("Failed to generate SSH key: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,9 +115,33 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
portManager := port.New()
|
||||||
|
rawRange := config.Getenv("ALLOWED_PORTS", "")
|
||||||
|
if rawRange != "" {
|
||||||
|
splitRange := strings.Split(rawRange, "-")
|
||||||
|
if len(splitRange) == 2 {
|
||||||
|
var start, end uint64
|
||||||
|
start, err = strconv.ParseUint(splitRange[0], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to parse start port: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
end, err = strconv.ParseUint(splitRange[1], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to parse end port: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = portManager.AddPortRange(uint16(start), uint16(end)); err != nil {
|
||||||
|
log.Fatalf("Failed to add port range: %s", err)
|
||||||
|
}
|
||||||
|
log.Printf("PortRegistry range configured: %d-%d", start, end)
|
||||||
|
} else {
|
||||||
|
log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange)
|
||||||
|
}
|
||||||
|
}
|
||||||
var app server.Server
|
var app server.Server
|
||||||
go func() {
|
go func() {
|
||||||
app, err = server.New(sshConfig, sessionRegistry, grpcClient)
|
app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("failed to start server: %s", err)
|
errChan <- fmt.Errorf("failed to start server: %s", err)
|
||||||
return
|
return
|
||||||
|
|||||||
+7
-4
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/internal/grpc/client"
|
"tunnel_pls/internal/grpc/client"
|
||||||
|
"tunnel_pls/internal/port"
|
||||||
"tunnel_pls/session"
|
"tunnel_pls/session"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
@@ -21,11 +22,12 @@ type Server interface {
|
|||||||
type server struct {
|
type server struct {
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
config *ssh.ServerConfig
|
config *ssh.ServerConfig
|
||||||
sessionRegistry session.Registry
|
|
||||||
grpcClient client.Client
|
grpcClient client.Client
|
||||||
|
sessionRegistry session.Registry
|
||||||
|
portRegistry port.Registry
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client) (Server, error) {
|
func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) {
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200")))
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200")))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to listen on port 2200: %v", err)
|
log.Fatalf("failed to listen on port 2200: %v", err)
|
||||||
@@ -50,8 +52,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie
|
|||||||
return &server{
|
return &server{
|
||||||
listener: listener,
|
listener: listener,
|
||||||
config: sshConfig,
|
config: sshConfig,
|
||||||
sessionRegistry: sessionRegistry,
|
|
||||||
grpcClient: grpcClient,
|
grpcClient: grpcClient,
|
||||||
|
sessionRegistry: sessionRegistry,
|
||||||
|
portRegistry: portRegistry,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,7 +106,7 @@ func (s *server) handleConnection(conn net.Conn) {
|
|||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
log.Println("SSH connection established:", sshConn.User())
|
log.Println("SSH connection established:", sshConn.User())
|
||||||
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user)
|
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
|
||||||
err = sshSession.Start()
|
err = sshSession.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("SSH session ended with error: %v", err)
|
log.Printf("SSH session ended with error: %v", err)
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ type Interaction interface {
|
|||||||
Mode() types.Mode
|
Mode() types.Mode
|
||||||
SetChannel(channel ssh.Channel)
|
SetChannel(channel ssh.Channel)
|
||||||
SetLifecycle(lifecycle Lifecycle)
|
SetLifecycle(lifecycle Lifecycle)
|
||||||
SetSessionRegistry(registry SessionRegistry)
|
|
||||||
SetMode(m types.Mode)
|
SetMode(m types.Mode)
|
||||||
SetWH(w, h int)
|
SetWH(w, h int)
|
||||||
Start()
|
Start()
|
||||||
@@ -80,24 +79,20 @@ func (i *interaction) SetWH(w, h int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(slug slug.Slug, forwarder Forwarder) Interaction {
|
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry) Interaction {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &interaction{
|
return &interaction{
|
||||||
channel: nil,
|
channel: nil,
|
||||||
slug: slug,
|
slug: slug,
|
||||||
forwarder: forwarder,
|
forwarder: forwarder,
|
||||||
lifecycle: nil,
|
lifecycle: nil,
|
||||||
sessionRegistry: nil,
|
sessionRegistry: sessionRegistry,
|
||||||
program: nil,
|
program: nil,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *interaction) SetSessionRegistry(registry SessionRegistry) {
|
|
||||||
i.sessionRegistry = registry
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
|
func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
|
||||||
i.lifecycle = lifecycle
|
i.lifecycle = lifecycle
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,41 +28,43 @@ type lifecycle struct {
|
|||||||
conn ssh.Conn
|
conn ssh.Conn
|
||||||
channel ssh.Channel
|
channel ssh.Channel
|
||||||
forwarder Forwarder
|
forwarder Forwarder
|
||||||
sessionRegistry SessionRegistry
|
|
||||||
slug slug.Slug
|
slug slug.Slug
|
||||||
startedAt time.Time
|
startedAt time.Time
|
||||||
|
sessionRegistry SessionRegistry
|
||||||
|
portRegistry portUtil.Registry
|
||||||
user string
|
user string
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle {
|
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle {
|
||||||
return &lifecycle{
|
return &lifecycle{
|
||||||
status: types.INITIALIZING,
|
status: types.INITIALIZING,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
channel: nil,
|
channel: nil,
|
||||||
forwarder: forwarder,
|
forwarder: forwarder,
|
||||||
slug: slugManager,
|
slug: slugManager,
|
||||||
sessionRegistry: nil,
|
|
||||||
startedAt: time.Now(),
|
startedAt: time.Now(),
|
||||||
|
sessionRegistry: sessionRegistry,
|
||||||
|
portRegistry: port,
|
||||||
user: user,
|
user: user,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) {
|
|
||||||
l.sessionRegistry = registry
|
|
||||||
}
|
|
||||||
|
|
||||||
type Lifecycle interface {
|
type Lifecycle interface {
|
||||||
Connection() ssh.Conn
|
Connection() ssh.Conn
|
||||||
Channel() ssh.Channel
|
Channel() ssh.Channel
|
||||||
|
PortRegistry() portUtil.Registry
|
||||||
User() string
|
User() string
|
||||||
SetChannel(channel ssh.Channel)
|
SetChannel(channel ssh.Channel)
|
||||||
SetSessionRegistry(registry SessionRegistry)
|
|
||||||
SetStatus(status types.Status)
|
SetStatus(status types.Status)
|
||||||
IsActive() bool
|
IsActive() bool
|
||||||
StartedAt() time.Time
|
StartedAt() time.Time
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *lifecycle) PortRegistry() portUtil.Registry {
|
||||||
|
return l.portRegistry
|
||||||
|
}
|
||||||
|
|
||||||
func (l *lifecycle) User() string {
|
func (l *lifecycle) User() string {
|
||||||
return l.user
|
return l.user
|
||||||
}
|
}
|
||||||
@@ -116,7 +118,7 @@ func (l *lifecycle) Close() error {
|
|||||||
l.sessionRegistry.Remove(key)
|
l.sessionRegistry.Remove(key)
|
||||||
|
|
||||||
if tunnelType == types.TCP {
|
if tunnelType == types.TCP {
|
||||||
if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
||||||
firstErr = err
|
firstErr = err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+12
-14
@@ -54,16 +54,14 @@ type session struct {
|
|||||||
|
|
||||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||||
|
|
||||||
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session {
|
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session {
|
||||||
slugManager := slug.New()
|
slugManager := slug.New()
|
||||||
forwarderManager := forwarder.New(slugManager)
|
forwarderManager := forwarder.New(slugManager)
|
||||||
interactionManager := interaction.New(slugManager, forwarderManager)
|
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry)
|
||||||
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user)
|
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
|
||||||
|
|
||||||
interactionManager.SetLifecycle(lifecycleManager)
|
interactionManager.SetLifecycle(lifecycleManager)
|
||||||
forwarderManager.SetLifecycle(lifecycleManager)
|
forwarderManager.SetLifecycle(lifecycleManager)
|
||||||
interactionManager.SetSessionRegistry(sessionRegistry)
|
|
||||||
lifecycleManager.SetSessionRegistry(sessionRegistry)
|
|
||||||
|
|
||||||
return &session{
|
return &session{
|
||||||
initialReq: initialReq,
|
initialReq: initialReq,
|
||||||
@@ -135,7 +133,7 @@ func (s *session) Start() error {
|
|||||||
|
|
||||||
tcpipReq := s.waitForTCPIPForward()
|
tcpipReq := s.waitForTCPIPForward()
|
||||||
if tcpipReq == nil {
|
if tcpipReq == nil {
|
||||||
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
|
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -234,7 +232,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) HandleTCPIPForward(req *ssh.Request) {
|
func (s *session) HandleTCPIPForward(req *ssh.Request) {
|
||||||
log.Println("Port forwarding request detected")
|
log.Println("PortRegistry forwarding request detected")
|
||||||
|
|
||||||
fail := func(msg string) {
|
fail := func(msg string) {
|
||||||
log.Println(msg)
|
log.Println(msg)
|
||||||
@@ -262,13 +260,13 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rawPortToBind > 65535 {
|
if rawPortToBind > 65535 {
|
||||||
fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind))
|
fail(fmt.Sprintf("PortRegistry %d is larger than allowed port of 65535", rawPortToBind))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
portToBind := uint16(rawPortToBind)
|
portToBind := uint16(rawPortToBind)
|
||||||
if isBlockedPort(portToBind) {
|
if isBlockedPort(portToBind) {
|
||||||
fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind))
|
fail(fmt.Sprintf("PortRegistry %d is blocked or restricted", portToBind))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,7 +338,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
|||||||
s.registry.Remove(*key)
|
s.registry.Remove(*key)
|
||||||
}
|
}
|
||||||
if port != 0 {
|
if port != 0 {
|
||||||
if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil {
|
if setErr := s.lifecycle.PortRegistry().SetPortStatus(port, false); setErr != nil {
|
||||||
log.Printf("Failed to reset port status: %v", setErr)
|
log.Printf("Failed to reset port status: %v", setErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -356,7 +354,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
|||||||
}
|
}
|
||||||
|
|
||||||
if portToBind == 0 {
|
if portToBind == 0 {
|
||||||
unassigned, ok := portUtil.Default.GetUnassignedPort()
|
unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort()
|
||||||
if !ok {
|
if !ok {
|
||||||
fail("No available port")
|
fail("No available port")
|
||||||
return
|
return
|
||||||
@@ -364,15 +362,15 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
|||||||
portToBind = unassigned
|
portToBind = unassigned
|
||||||
}
|
}
|
||||||
|
|
||||||
if claimed := portUtil.Default.ClaimPort(portToBind); !claimed {
|
if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed {
|
||||||
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
fail(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil)
|
cleanup(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind), portToBind, nil, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user