5 Commits

Author SHA1 Message Date
abd103b5ab fix(port): add atomic ClaimPort() to prevent race condition
All checks were successful
Docker Build and Push / build-and-push-tags (push) Successful in 3m23s
Docker Build and Push / build-and-push-branches (push) Has been skipped
- Replace GetPortStatus/SetPortStatus calls with atomic ClaimPort() operation.
- Fixed a logic error when handling headless tunneling.
2026-01-12 18:25:35 +07:00
560c98b869 refactor: consolidate error handling with fail() function in session handlers
All checks were successful
Docker Build and Push / build-and-push-tags (push) Successful in 3m21s
Docker Build and Push / build-and-push-branches (push) Has been skipped
- Replace repetitive error handling code with fail() function in HandleGlobalRequest
- Standardize error response pattern across all handler methods
- Improve code maintainability and reduce duplication
2026-01-12 14:42:42 +07:00
e1f5d73e03 feat: add headless mode support for SSH -N connections
All checks were successful
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 3m3s
- use s.lifecycle.GetConnection().Wait() to block until SSH connection closes
- Prevent premature session closure in headless mode

In headless mode (ssh -N), there's no channel interaction to block on,
so the session would immediately return and close. Now blocking on
conn.Wait() keeps the session alive until the client disconnects.
2026-01-11 15:21:11 +07:00
19fd6d59d2 Merge pull request 'main' (#62) from main into staging
All checks were successful
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 3m32s
Reviewed-on: #62
2026-01-09 12:15:30 +00:00
e3988b339f Merge pull request 'fix(deps): update module github.com/caddyserver/certmagic to v0.25.1' (#61) from renovate/github.com-caddyserver-certmagic-0.x into main
All checks were successful
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 3m21s
Reviewed-on: #61
2026-01-09 12:15:05 +00:00
6 changed files with 163 additions and 160 deletions

View File

@@ -13,7 +13,7 @@ type Manager interface {
AddPortRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error
GetPortStatus(port uint16) (bool, bool)
ClaimPort(port uint16) (claimed bool)
}
type manager struct {
@@ -74,7 +74,6 @@ func (pm *manager) GetUnassignedPort() (uint16, bool) {
for _, port := range pm.sortedPorts {
if !pm.ports[port] {
pm.ports[port] = true
return port, true
}
}
@@ -89,10 +88,21 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
return nil
}
func (pm *manager) GetPortStatus(port uint16) (bool, bool) {
pm.mu.RLock()
defer pm.mu.RUnlock()
func (pm *manager) ClaimPort(port uint16) (claimed bool) {
pm.mu.Lock()
defer pm.mu.Unlock()
status, exists := pm.ports[port]
return status, exists
if exists && status {
return false
}
if !exists {
pm.ports[port] = true
return true
}
pm.ports[port] = true
return true
}

View File

@@ -90,7 +90,6 @@ func (s *Server) handleConnection(conn net.Conn) {
user = u
cancel()
}
log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user)
err = sshSession.Start()

View File

@@ -59,143 +59,79 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected")
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
}
reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader)
if err != nil {
log.Println("Failed to read address from payload:", err)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
fail(fmt.Sprintf("Failed to read address from payload: %v", err))
return
}
var rawPortToBind uint32
if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
log.Println("Failed to read port from payload:", err)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
fail(fmt.Sprintf("Failed to read port from payload: %v", err))
return
}
if rawPortToBind > 65535 {
log.Printf("Port %d is larger than allowed port of 65535", rawPortToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind))
return
}
portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) {
log.Printf("Port %d is blocked or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind))
return
}
if portToBind == 80 || portToBind == 443 {
switch portToBind {
case 80, 443:
s.HandleHTTPForward(req, portToBind)
return
}
if portToBind == 0 {
unassign, success := portUtil.Default.GetUnassignedPort()
portToBind = unassign
if !success {
log.Println("No available port")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
log.Printf("Port %d is already in use or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
err = portUtil.Default.SetPortStatus(portToBind, true)
if err != nil {
log.Println("Failed to set port status:", err)
return
}
default:
s.HandleTCPForward(req, addr, portToBind)
}
}
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
slug := random.GenerateRandomString(20)
key := types.SessionKey{Id: slug, Type: types.HTTP}
if !s.registry.Register(key, s) {
log.Printf("Failed to register client with slug: %s", slug)
err := req.Reply(false, nil)
if err != nil {
fail := func(msg string, key *types.SessionKey) {
log.Println(msg)
if key != nil {
s.registry.Remove(*key)
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
}
}
slug := random.GenerateRandomString(20)
key := types.SessionKey{Id: slug, Type: types.HTTP}
if !s.registry.Register(key, s) {
fail(fmt.Sprintf("Failed to register client with slug: %s", slug), nil)
return
}
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil {
log.Println("Failed to write port to buffer:", err)
s.registry.Remove(key)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key)
return
}
log.Printf("HTTP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes())
if err != nil {
log.Println("Failed to reply to request:", err)
s.registry.Remove(key)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
fail(fmt.Sprintf("Failed to reply to request: %v", err), &key)
return
}
@@ -203,76 +139,80 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
s.forwarder.SetForwardedPort(portToBind)
s.slugManager.Set(slug)
s.lifecycle.SetStatus(types.RUNNING)
s.interaction.Start()
}
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil {
log.Printf("Port %d is already in use or restricted", portToBind)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = req.Reply(false, nil)
if err != nil {
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
}
cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) {
log.Println(msg)
if key != nil {
s.registry.Remove(*key)
}
if port != 0 {
if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
}
if listener != nil {
if closeErr := listener.Close(); closeErr != nil {
log.Printf("Failed to close listener: %v", closeErr)
}
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
}
_ = s.lifecycle.Close()
}
if portToBind == 0 {
unassigned, ok := portUtil.Default.GetUnassignedPort()
if !ok {
fail("No available port")
return
}
portToBind = unassigned
}
if claimed := portUtil.Default.ClaimPort(portToBind); !claimed {
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
return
}
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil {
cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil)
return
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP}
if !s.registry.Register(key, s) {
log.Printf("Failed to register TCP client with id: %s", key.Id)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
if closeErr := listener.Close(); closeErr != nil {
log.Printf("Failed to close listener: %s", closeErr)
}
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
_ = s.lifecycle.Close()
cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil)
return
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil {
log.Println("Failed to write port to buffer:", err)
s.registry.Remove(key)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key)
return
}
log.Printf("TCP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes())
if err != nil {
log.Println("Failed to reply to request:", err)
s.registry.Remove(key)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key)
return
}
@@ -282,7 +222,6 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
s.slugManager.Set(key.Id)
s.lifecycle.SetStatus(types.RUNNING)
go s.forwarder.AcceptTCPConnections()
s.interaction.Start()
}
func readSSHString(reader *bytes.Reader) (string, error) {

View File

@@ -37,6 +37,9 @@ type Controller interface {
SetWH(w, h int)
Redraw()
SetSessionRegistry(registry SessionRegistry)
SetMode(m types.Mode)
GetMode() types.Mode
Send(message string) error
}
type Forwarder interface {
@@ -54,8 +57,24 @@ type Interaction struct {
program *tea.Program
ctx context.Context
cancel context.CancelFunc
mode types.Mode
}
func (i *Interaction) SetMode(m types.Mode) {
i.mode = m
}
func (i *Interaction) GetMode() types.Mode {
return i.mode
}
func (i *Interaction) Send(message string) error {
if i.channel != nil {
_, err := i.channel.Write([]byte(message))
return err
}
return nil
}
func (i *Interaction) SetWH(w, h int) {
if i.program != nil {
i.program.Send(tea.WindowSizeMsg{
@@ -749,6 +768,9 @@ func (m *model) View() string {
}
func (i *Interaction) Start() {
if i.mode == types.HEADLESS {
return
}
lipgloss.SetColorProfile(termenv.TrueColor)
domain := config.Getenv("DOMAIN", "localhost")

View File

@@ -9,6 +9,7 @@ import (
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
@@ -87,7 +88,14 @@ func (s *SSHSession) Detail() Detail {
}
func (s *SSHSession) Start() error {
channel := <-s.sshReqChannel
var channel ssh.NewChannel
var ok bool
select {
case channel, ok = <-s.sshReqChannel:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
@@ -95,23 +103,41 @@ func (s *SSHSession) Start() error {
}
go s.HandleGlobalRequest(reqs)
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.INTERACTIVE)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.HEADLESS)
}
tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil {
_, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))))
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
if err != nil {
return err
}
if err := s.lifecycle.Close(); err != nil {
if err = s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("no forwarding Request")
}
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.GetUser() == "UNAUTHORIZED" {
if err := tcpipReq.Reply(false, nil); err != nil {
log.Printf("cannot reply to tcpip req: %s\n", err)
return err
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err
}
return nil
}
s.HandleTCPIPForward(tcpipReq)
s.interaction.Start()
s.lifecycle.GetConnection().Wait()
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err

View File

@@ -8,6 +8,13 @@ const (
SETUP Status = "SETUP"
)
type Mode string
const (
INTERACTIVE Mode = "INTERACTIVE"
HEADLESS Mode = "HEADLESS"
)
type TunnelType string
const (