refactor: remove custom parsing functions and use ssh.Marshal/ssh.Unmarshal for serialization
This commit is contained in:
+31
-64
@@ -1,7 +1,6 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -173,13 +172,11 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
|
||||
}
|
||||
|
||||
func (s *session) handleMissingForwardRequest() error {
|
||||
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
|
||||
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = s.lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("no forwarding Request")
|
||||
}
|
||||
|
||||
@@ -239,8 +236,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
||||
for req := range GlobalRequest {
|
||||
switch req.Type {
|
||||
case "shell", "pty-req":
|
||||
err := req.Reply(true, nil)
|
||||
if err != nil {
|
||||
if err := req.Reply(true, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
case "window-change":
|
||||
@@ -249,8 +245,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
||||
}
|
||||
default:
|
||||
log.Println("Unknown request type:", req.Type)
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
if err := req.Reply(false, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -258,24 +253,24 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) {
|
||||
address, err = readSSHString(payloadReader)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) {
|
||||
var forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
var rawPortToBind uint32
|
||||
if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil {
|
||||
return "", 0, err
|
||||
if err = ssh.Unmarshal(payload, &forwardPayload); err != nil {
|
||||
return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err)
|
||||
}
|
||||
|
||||
if rawPortToBind > 65535 {
|
||||
if forwardPayload.BindPort > 65535 {
|
||||
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
|
||||
}
|
||||
|
||||
port = uint16(rawPortToBind)
|
||||
port = uint16(forwardPayload.BindPort)
|
||||
|
||||
if isBlockedPort(port) {
|
||||
return "", 0, fmt.Errorf("port is block")
|
||||
return "", 0, fmt.Errorf("port is blocked")
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
@@ -283,10 +278,10 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string,
|
||||
if !ok {
|
||||
return "", 0, fmt.Errorf("no available port")
|
||||
}
|
||||
return address, unassigned, err
|
||||
return forwardPayload.BindAddr, unassigned, nil
|
||||
}
|
||||
|
||||
return address, port, err
|
||||
return forwardPayload.BindAddr, port, nil
|
||||
}
|
||||
|
||||
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
|
||||
@@ -294,37 +289,25 @@ func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey,
|
||||
if key != nil {
|
||||
s.registry.Remove(*key)
|
||||
}
|
||||
|
||||
if listener != nil {
|
||||
if err := listener.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("close listener: %w", err))
|
||||
}
|
||||
}
|
||||
if err := req.Reply(false, nil); err != nil {
|
||||
errs = append(errs, fmt.Errorf("reply request: %w", err))
|
||||
}
|
||||
if err := s.lifecycle.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("close session: %w", err))
|
||||
errs = append(errs, listener.Close())
|
||||
}
|
||||
|
||||
errs = append(errs, req.Reply(false, nil))
|
||||
errs = append(errs, s.lifecycle.Close())
|
||||
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) {
|
||||
buf := new(bytes.Buffer)
|
||||
err = binary.Write(buf, binary.BigEndian, uint32(port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.Reply(true, buf.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
|
||||
err := s.approveForwardingRequest(req, portToBind)
|
||||
replyPayload := struct {
|
||||
BoundPort uint32
|
||||
}{
|
||||
BoundPort: uint32(portToBind),
|
||||
}
|
||||
err := req.Reply(true, ssh.Marshal(replyPayload))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -342,9 +325,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
|
||||
}
|
||||
|
||||
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
|
||||
reader := bytes.NewReader(req.Payload)
|
||||
|
||||
address, port, err := s.parseForwardPayload(reader)
|
||||
address, port, err := s.parseForwardPayload(req.Payload)
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
|
||||
}
|
||||
@@ -376,13 +357,13 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
|
||||
|
||||
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
|
||||
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||
}
|
||||
|
||||
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
|
||||
listener, err := tcpServer.Listen()
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||
}
|
||||
|
||||
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
|
||||
@@ -405,20 +386,6 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
return nil
|
||||
}
|
||||
|
||||
func readSSHString(reader io.Reader) (string, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
|
||||
return "", err
|
||||
}
|
||||
strBytes := make([]byte, length)
|
||||
if length > 0 {
|
||||
if _, err := io.ReadFull(reader, strBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return string(strBytes), nil
|
||||
}
|
||||
|
||||
func isBlockedPort(port uint16) bool {
|
||||
if port == 80 || port == 443 {
|
||||
return false
|
||||
|
||||
Reference in New Issue
Block a user