refactor: remove custom parsing functions and use ssh.Marshal/ssh.Unmarshal for serialization
SonarQube Scan / SonarQube Trigger (push) Successful in 2m14s

This commit is contained in:
2026-01-25 12:21:25 +07:00
parent e59fea6604
commit 2b488a5ab5
4 changed files with 50 additions and 178 deletions
+31 -64
View File
@@ -1,7 +1,6 @@
package session
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
@@ -173,13 +172,11 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
}
func (s *session) handleMissingForwardRequest() error {
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
if err != nil {
return err
}
if err = s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("no forwarding Request")
}
@@ -239,8 +236,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
for req := range GlobalRequest {
switch req.Type {
case "shell", "pty-req":
err := req.Reply(true, nil)
if err != nil {
if err := req.Reply(true, nil); err != nil {
return err
}
case "window-change":
@@ -249,8 +245,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
}
default:
log.Println("Unknown request type:", req.Type)
err := req.Reply(false, nil)
if err != nil {
if err := req.Reply(false, nil); err != nil {
return err
}
}
@@ -258,24 +253,24 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
return nil
}
func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) {
address, err = readSSHString(payloadReader)
if err != nil {
return "", 0, err
func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) {
var forwardPayload struct {
BindAddr string
BindPort uint32
}
var rawPortToBind uint32
if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil {
return "", 0, err
if err = ssh.Unmarshal(payload, &forwardPayload); err != nil {
return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err)
}
if rawPortToBind > 65535 {
if forwardPayload.BindPort > 65535 {
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
}
port = uint16(rawPortToBind)
port = uint16(forwardPayload.BindPort)
if isBlockedPort(port) {
return "", 0, fmt.Errorf("port is block")
return "", 0, fmt.Errorf("port is blocked")
}
if port == 0 {
@@ -283,10 +278,10 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string,
if !ok {
return "", 0, fmt.Errorf("no available port")
}
return address, unassigned, err
return forwardPayload.BindAddr, unassigned, nil
}
return address, port, err
return forwardPayload.BindAddr, port, nil
}
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
@@ -294,37 +289,25 @@ func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey,
if key != nil {
s.registry.Remove(*key)
}
if listener != nil {
if err := listener.Close(); err != nil {
errs = append(errs, fmt.Errorf("close listener: %w", err))
}
}
if err := req.Reply(false, nil); err != nil {
errs = append(errs, fmt.Errorf("reply request: %w", err))
}
if err := s.lifecycle.Close(); err != nil {
errs = append(errs, fmt.Errorf("close session: %w", err))
errs = append(errs, listener.Close())
}
errs = append(errs, req.Reply(false, nil))
errs = append(errs, s.lifecycle.Close())
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
return errors.Join(errs...)
}
func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) {
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(port))
if err != nil {
return err
}
err = req.Reply(true, buf.Bytes())
if err != nil {
return err
}
return nil
}
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
err := s.approveForwardingRequest(req, portToBind)
replyPayload := struct {
BoundPort uint32
}{
BoundPort: uint32(portToBind),
}
err := req.Reply(true, ssh.Marshal(replyPayload))
if err != nil {
return err
}
@@ -342,9 +325,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
}
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
reader := bytes.NewReader(req.Payload)
address, port, err := s.parseForwardPayload(reader)
address, port, err := s.parseForwardPayload(req.Payload)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
}
@@ -376,13 +357,13 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
}
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
listener, err := tcpServer.Listen()
if err != nil {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
@@ -405,20 +386,6 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
return nil
}
func readSSHString(reader io.Reader) (string, error) {
var length uint32
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
return "", err
}
strBytes := make([]byte, length)
if length > 0 {
if _, err := io.ReadFull(reader, strBytes); err != nil {
return "", err
}
}
return string(strBytes), nil
}
func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 {
return false