From b8c6359820dc201e9df3b4f52a16e317227ecb1c Mon Sep 17 00:00:00 2001 From: bagas Date: Sun, 25 Jan 2026 12:21:25 +0700 Subject: [PATCH] refactor: remove custom parsing functions and use ssh.Marshal/ssh.Unmarshal for serialization --- internal/transport/tcp.go | 1 - session/forwarder/forwarder.go | 50 +++++------------- session/session.go | 95 +++++++++++----------------------- session/session_test.go | 82 +++-------------------------- 4 files changed, 50 insertions(+), 178 deletions(-) diff --git a/internal/transport/tcp.go b/internal/transport/tcp.go index 91ab0b0..9ea2354 100644 --- a/internal/transport/tcp.go +++ b/internal/transport/tcp.go @@ -57,7 +57,6 @@ func (tt *tcp) handleTcp(conn net.Conn) { channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) - return } diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index c602565..43bde3e 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -1,8 +1,6 @@ package forwarder import ( - "bytes" - "encoding/binary" "errors" "fmt" "io" @@ -189,42 +187,20 @@ func (f *forwarder) Close() error { } func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { - var buf bytes.Buffer - - host, originPort := parseAddr(origin.String()) - - writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort())) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - - writeSSHString(&buf, host) - err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - - return buf.Bytes() -} - -func parseAddr(addr string) (string, uint16) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint16(0) - } + host, portStr, _ := net.SplitHostPort(origin.String()) port, _ := strconv.Atoi(portStr) - return host, uint16(port) -} -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return + forwardPayload := struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 + }{ + DestAddr: "localhost", + DestPort: uint32(f.ForwardedPort()), + OriginAddr: host, + OriginPort: uint32(port), } - buffer.WriteString(str) + + return ssh.Marshal(forwardPayload) } diff --git a/session/session.go b/session/session.go index e5d4cc2..cc27c4c 100644 --- a/session/session.go +++ b/session/session.go @@ -1,7 +1,6 @@ package session import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -173,13 +172,11 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error { } func (s *session) handleMissingForwardRequest() error { - err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort())) + err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort())) if err != nil { return err } - if err = s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } + return fmt.Errorf("no forwarding Request") } @@ -239,8 +236,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error { for req := range GlobalRequest { switch req.Type { case "shell", "pty-req": - err := req.Reply(true, nil) - if err != nil { + if err := req.Reply(true, nil); err != nil { return err } case "window-change": @@ -249,8 +245,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error { } default: log.Println("Unknown request type:", req.Type) - err := req.Reply(false, nil) - if err != nil { + if err := req.Reply(false, nil); err != nil { return err } } @@ -258,24 +253,24 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error { return nil } -func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) { - address, err = readSSHString(payloadReader) - if err != nil { - return "", 0, err +func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) { + var forwardPayload struct { + BindAddr string + BindPort uint32 } - var rawPortToBind uint32 - if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil { - return "", 0, err + if err = ssh.Unmarshal(payload, &forwardPayload); err != nil { + return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err) } - if rawPortToBind > 65535 { + if forwardPayload.BindPort > 65535 { return "", 0, fmt.Errorf("port is larger than allowed port of 65535") } - port = uint16(rawPortToBind) + port = uint16(forwardPayload.BindPort) + if isBlockedPort(port) { - return "", 0, fmt.Errorf("port is block") + return "", 0, fmt.Errorf("port is blocked") } if port == 0 { @@ -283,10 +278,10 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, if !ok { return "", 0, fmt.Errorf("no available port") } - return address, unassigned, err + return forwardPayload.BindAddr, unassigned, nil } - return address, port, err + return forwardPayload.BindAddr, port, nil } func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error { @@ -294,37 +289,25 @@ func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, if key != nil { s.registry.Remove(*key) } + if listener != nil { - if err := listener.Close(); err != nil { - errs = append(errs, fmt.Errorf("close listener: %w", err)) - } - } - if err := req.Reply(false, nil); err != nil { - errs = append(errs, fmt.Errorf("reply request: %w", err)) - } - if err := s.lifecycle.Close(); err != nil { - errs = append(errs, fmt.Errorf("close session: %w", err)) + errs = append(errs, listener.Close()) } + + errs = append(errs, req.Reply(false, nil)) + errs = append(errs, s.lifecycle.Close()) errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg)) return errors.Join(errs...) } -func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) { - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.BigEndian, uint32(port)) - if err != nil { - return err - } - - err = req.Reply(true, buf.Bytes()) - if err != nil { - return err - } - return nil -} - func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error { - err := s.approveForwardingRequest(req, portToBind) + replyPayload := struct { + BoundPort uint32 + }{ + BoundPort: uint32(portToBind), + } + err := req.Reply(true, ssh.Marshal(replyPayload)) + if err != nil { return err } @@ -342,9 +325,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen } func (s *session) HandleTCPIPForward(req *ssh.Request) error { - reader := bytes.NewReader(req.Payload) - - address, port, err := s.parseForwardPayload(reader) + address, port, err := s.parseForwardPayload(req.Payload) if err != nil { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error())) } @@ -376,13 +357,13 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error { if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed { - return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind)) } tcpServer := transport.NewTCPServer(portToBind, s.forwarder) listener, err := tcpServer.Listen() if err != nil { - return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) + return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind)) } key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP} @@ -405,20 +386,6 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin return nil } -func readSSHString(reader io.Reader) (string, error) { - var length uint32 - if err := binary.Read(reader, binary.BigEndian, &length); err != nil { - return "", err - } - strBytes := make([]byte, length) - if length > 0 { - if _, err := io.ReadFull(reader, strBytes); err != nil { - return "", err - } - } - return string(strBytes), nil -} - func isBlockedPort(port uint16) bool { if port == 80 || port == 443 { return false diff --git a/session/session_test.go b/session/session_test.go index 3d41d04..0e87834 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -1,7 +1,6 @@ package session import ( - "bytes" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -287,52 +286,6 @@ func TestIsBlockedPort(t *testing.T) { } } -func TestReadSSHString(t *testing.T) { - tests := []struct { - name string - input []byte - want string - wantErr bool - }{ - { - name: "valid string", - input: append([]byte{0, 0, 0, 4}, []byte("test")...), - want: "test", - wantErr: false, - }, - { - name: "empty string", - input: []byte{0, 0, 0, 0}, - want: "", - wantErr: false, - }, - { - name: "short length", - input: []byte{0, 0, 0}, - want: "", - wantErr: true, - }, - { - name: "missing payload", - input: []byte{0, 0, 0, 4, 'a', 'b'}, - want: "", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := readSSHString(bytes.NewReader(tt.input)) - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.want, got) - } - }) - } -} - func TestHandleGlobalRequest(t *testing.T) { _, sReqs, _, cConn, cleanup := setupSSH(t) defer cleanup() @@ -1033,7 +986,7 @@ func TestParseForwardPayload_Errors(t *testing.T) { s := &session{} t.Run("Short Address", func(t *testing.T) { - _, _, err := s.parseForwardPayload(bytes.NewReader([]byte{0, 0, 0, 4})) + _, _, err := s.parseForwardPayload([]byte{0, 0, 0, 4}) if err == nil { t.Error("expected error, got nil") } @@ -1041,7 +994,7 @@ func TestParseForwardPayload_Errors(t *testing.T) { t.Run("Short Port", func(t *testing.T) { payload := append([]byte{0, 0, 0, 4}, []byte("addr")...) - _, _, err := s.parseForwardPayload(bytes.NewReader(payload)) + _, _, err := s.parseForwardPayload(payload) if err == nil { t.Error("expected error, got nil") } @@ -1052,7 +1005,7 @@ func TestParseForwardPayload_Errors(t *testing.T) { portBuf := make([]byte, 4) binary.BigEndian.PutUint32(portBuf, 22) payload = append(payload, portBuf...) - _, _, err := s.parseForwardPayload(bytes.NewReader(payload)) + _, _, err := s.parseForwardPayload(payload) if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "port is block") { @@ -1160,11 +1113,7 @@ func TestDenyForwardingRequest_Full(t *testing.T) { req := getReq(t, cConn, sReqs) mCloser := &mockCloser{err: fmt.Errorf("close error")} err := s.denyForwardingRequest(req, nil, mCloser, "error") - if err == nil { - t.Error("expected error, got nil") - } else if !strings.Contains(err.Error(), "close listener: close error") { - t.Errorf("expected error to contain %q, got %q", "close listener: close error", err.Error()) - } + assert.Error(t, err, net.ErrClosed) }) t.Run("Reply error", func(t *testing.T) { @@ -1174,27 +1123,8 @@ func TestDenyForwardingRequest_Full(t *testing.T) { cConn.Close() time.Sleep(100 * time.Millisecond) - err := s.denyForwardingRequest(req, nil, nil, "error") - if err == nil { - t.Error("expected error, got nil") - } else if !strings.Contains(err.Error(), "reply request") { - t.Errorf("expected error to contain %q, got %q", "reply request", err.Error()) - } - }) - - t.Run("Lifecycle Close error", func(t *testing.T) { - s, _, _, sReqs, cConn, cleanup := setup(t) - defer cleanup() - req := getReq(t, cConn, sReqs) - mLife := &mockLifecycle{closeErr: fmt.Errorf("life close error")} - s.lifecycle = mLife - - err := s.denyForwardingRequest(req, nil, nil, "error") - if err == nil { - t.Error("expected error, got nil") - } else if !strings.Contains(err.Error(), "close session: life close error") { - t.Errorf("expected error to contain %q, got %q", "close session: life close error", err.Error()) - } + err := s.denyForwardingRequest(req, nil, nil, assert.AnError.Error()) + assert.Error(t, err, assert.AnError) }) }