refactor: consolidate error handling with fail() function in session handlers
All checks were successful
Docker Build and Push / build-and-push-tags (push) Successful in 3m21s
Docker Build and Push / build-and-push-branches (push) Has been skipped

- Replace repetitive error handling code with fail() function in HandleGlobalRequest
- Standardize error response pattern across all handler methods
- Improve code maintainability and reduce duplication
This commit is contained in:
2026-01-12 14:42:42 +07:00
parent e1f5d73e03
commit 560c98b869
2 changed files with 86 additions and 140 deletions

View File

@@ -74,7 +74,6 @@ func (pm *manager) GetUnassignedPort() (uint16, bool) {
for _, port := range pm.sortedPorts { for _, port := range pm.sortedPorts {
if !pm.ports[port] { if !pm.ports[port] {
pm.ports[port] = true
return port, true return port, true
} }
} }

View File

@@ -59,142 +59,79 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected") log.Println("Port forwarding request detected")
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
}
reader := bytes.NewReader(req.Payload) reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader) addr, err := readSSHString(reader)
if err != nil { if err != nil {
log.Println("Failed to read address from payload:", err) fail(fmt.Sprintf("Failed to read address from payload: %v", err))
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
var rawPortToBind uint32 var rawPortToBind uint32
if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
log.Println("Failed to read port from payload:", err) fail(fmt.Sprintf("Failed to read port from payload: %v", err))
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
if rawPortToBind > 65535 { if rawPortToBind > 65535 {
log.Printf("Port %d is larger than allowed port of 65535", rawPortToBind) fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind))
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
portToBind := uint16(rawPortToBind) portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) { if isBlockedPort(portToBind) {
log.Printf("Port %d is blocked or restricted", portToBind) fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind))
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
if portToBind == 80 || portToBind == 443 { switch portToBind {
case 80, 443:
s.HandleHTTPForward(req, portToBind) s.HandleHTTPForward(req, portToBind)
return default:
s.HandleTCPForward(req, addr, portToBind)
} }
if portToBind == 0 {
unassign, success := portUtil.Default.GetUnassignedPort()
portToBind = unassign
if !success {
log.Println("No available port")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
log.Printf("Port %d is already in use or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
err = portUtil.Default.SetPortStatus(portToBind, true)
if err != nil {
log.Println("Failed to set port status:", err)
return
}
s.HandleTCPForward(req, addr, portToBind)
} }
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
fail := func(msg string, key *types.SessionKey) {
log.Println(msg)
if key != nil {
s.registry.Remove(*key)
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
}
}
slug := random.GenerateRandomString(20) slug := random.GenerateRandomString(20)
key := types.SessionKey{Id: slug, Type: types.HTTP} key := types.SessionKey{Id: slug, Type: types.HTTP}
if !s.registry.Register(key, s) { if !s.registry.Register(key, s) {
log.Printf("Failed to register client with slug: %s", slug) fail(fmt.Sprintf("Failed to register client with slug: %s", slug), nil)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return return
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
log.Println("Failed to write port to buffer:", err) fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key)
s.registry.Remove(key)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return return
} }
log.Printf("HTTP forwarding approved on port: %d", portToBind) log.Printf("HTTP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes()) err = req.Reply(true, buf.Bytes())
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) fail(fmt.Sprintf("Failed to reply to request: %v", err), &key)
s.registry.Remove(key)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return return
} }
@@ -205,72 +142,82 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
} }
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
log.Printf("Requested forwarding on %s:%d", addr, portToBind) fail := func(msg string) {
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) log.Println(msg)
if err != nil { if err := req.Reply(false, nil); err != nil {
log.Printf("Port %d is already in use or restricted", portToBind)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.lifecycle.Close() if err := s.lifecycle.Close(); err != nil {
if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
}
cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) {
log.Println(msg)
if key != nil {
s.registry.Remove(*key)
}
if port != 0 {
if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
}
if listener != nil {
if closeErr := listener.Close(); closeErr != nil {
log.Printf("Failed to close listener: %v", closeErr)
}
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
}
_ = s.lifecycle.Close()
}
if portToBind == 0 {
unassigned, ok := portUtil.Default.GetUnassignedPort()
if !ok {
fail("No available port")
return
}
portToBind = unassigned
}
if isUsed, exists := portUtil.Default.GetPortStatus(portToBind); exists && isUsed {
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
return
}
if err := portUtil.Default.SetPortStatus(portToBind, true); err != nil {
fail(fmt.Sprintf("Failed to set port status: %v", err))
return
}
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil {
cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil)
return return
} }
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP}
if !s.registry.Register(key, s) { if !s.registry.Register(key, s) {
log.Printf("Failed to register TCP client with id: %s", key.Id) cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
if closeErr := listener.Close(); closeErr != nil {
log.Printf("Failed to close listener: %s", closeErr)
}
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
_ = s.lifecycle.Close()
return return
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
log.Println("Failed to write port to buffer:", err) cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key)
s.registry.Remove(key)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return return
} }
log.Printf("TCP forwarding approved on port: %d", portToBind) log.Printf("TCP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes()) err = req.Reply(true, buf.Bytes())
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key)
s.registry.Remove(key)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return return
} }