revert-54069ad305 #11
@@ -57,7 +57,6 @@ func (tt *tcp) handleTcp(conn net.Conn) {
|
|||||||
channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
|
channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -189,42 +187,20 @@ func (f *forwarder) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||||
var buf bytes.Buffer
|
host, portStr, _ := net.SplitHostPort(origin.String())
|
||||||
|
|
||||||
host, originPort := parseAddr(origin.String())
|
|
||||||
|
|
||||||
writeSSHString(&buf, "localhost")
|
|
||||||
err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write string to buffer: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
writeSSHString(&buf, host)
|
|
||||||
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write string to buffer: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseAddr(addr string) (string, uint16) {
|
|
||||||
host, portStr, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
|
||||||
return "0.0.0.0", uint16(0)
|
|
||||||
}
|
|
||||||
port, _ := strconv.Atoi(portStr)
|
port, _ := strconv.Atoi(portStr)
|
||||||
return host, uint16(port)
|
|
||||||
|
forwardPayload := struct {
|
||||||
|
DestAddr string
|
||||||
|
DestPort uint32
|
||||||
|
OriginAddr string
|
||||||
|
OriginPort uint32
|
||||||
|
}{
|
||||||
|
DestAddr: "localhost",
|
||||||
|
DestPort: uint32(f.ForwardedPort()),
|
||||||
|
OriginAddr: host,
|
||||||
|
OriginPort: uint32(port),
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
return ssh.Marshal(forwardPayload)
|
||||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write string to buffer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buffer.WriteString(str)
|
|
||||||
}
|
}
|
||||||
|
|||||||
+31
-64
@@ -1,7 +1,6 @@
|
|||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -173,13 +172,11 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handleMissingForwardRequest() error {
|
func (s *session) handleMissingForwardRequest() error {
|
||||||
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
|
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = s.lifecycle.Close(); err != nil {
|
|
||||||
log.Printf("failed to close session: %v", err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("no forwarding Request")
|
return fmt.Errorf("no forwarding Request")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,8 +236,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
|||||||
for req := range GlobalRequest {
|
for req := range GlobalRequest {
|
||||||
switch req.Type {
|
switch req.Type {
|
||||||
case "shell", "pty-req":
|
case "shell", "pty-req":
|
||||||
err := req.Reply(true, nil)
|
if err := req.Reply(true, nil); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case "window-change":
|
case "window-change":
|
||||||
@@ -249,8 +245,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
|||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
log.Println("Unknown request type:", req.Type)
|
log.Println("Unknown request type:", req.Type)
|
||||||
err := req.Reply(false, nil)
|
if err := req.Reply(false, nil); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -258,24 +253,24 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) {
|
func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) {
|
||||||
address, err = readSSHString(payloadReader)
|
var forwardPayload struct {
|
||||||
if err != nil {
|
BindAddr string
|
||||||
return "", 0, err
|
BindPort uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
var rawPortToBind uint32
|
if err = ssh.Unmarshal(payload, &forwardPayload); err != nil {
|
||||||
if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil {
|
return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err)
|
||||||
return "", 0, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rawPortToBind > 65535 {
|
if forwardPayload.BindPort > 65535 {
|
||||||
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
|
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
|
||||||
}
|
}
|
||||||
|
|
||||||
port = uint16(rawPortToBind)
|
port = uint16(forwardPayload.BindPort)
|
||||||
|
|
||||||
if isBlockedPort(port) {
|
if isBlockedPort(port) {
|
||||||
return "", 0, fmt.Errorf("port is block")
|
return "", 0, fmt.Errorf("port is blocked")
|
||||||
}
|
}
|
||||||
|
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
@@ -283,10 +278,10 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string,
|
|||||||
if !ok {
|
if !ok {
|
||||||
return "", 0, fmt.Errorf("no available port")
|
return "", 0, fmt.Errorf("no available port")
|
||||||
}
|
}
|
||||||
return address, unassigned, err
|
return forwardPayload.BindAddr, unassigned, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return address, port, err
|
return forwardPayload.BindAddr, port, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
|
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
|
||||||
@@ -294,37 +289,25 @@ func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey,
|
|||||||
if key != nil {
|
if key != nil {
|
||||||
s.registry.Remove(*key)
|
s.registry.Remove(*key)
|
||||||
}
|
}
|
||||||
|
|
||||||
if listener != nil {
|
if listener != nil {
|
||||||
if err := listener.Close(); err != nil {
|
errs = append(errs, listener.Close())
|
||||||
errs = append(errs, fmt.Errorf("close listener: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := req.Reply(false, nil); err != nil {
|
|
||||||
errs = append(errs, fmt.Errorf("reply request: %w", err))
|
|
||||||
}
|
|
||||||
if err := s.lifecycle.Close(); err != nil {
|
|
||||||
errs = append(errs, fmt.Errorf("close session: %w", err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
errs = append(errs, req.Reply(false, nil))
|
||||||
|
errs = append(errs, s.lifecycle.Close())
|
||||||
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
|
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
|
||||||
return errors.Join(errs...)
|
return errors.Join(errs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) {
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
err = binary.Write(buf, binary.BigEndian, uint32(port))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = req.Reply(true, buf.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
|
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
|
||||||
err := s.approveForwardingRequest(req, portToBind)
|
replyPayload := struct {
|
||||||
|
BoundPort uint32
|
||||||
|
}{
|
||||||
|
BoundPort: uint32(portToBind),
|
||||||
|
}
|
||||||
|
err := req.Reply(true, ssh.Marshal(replyPayload))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -342,9 +325,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
|
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
|
||||||
reader := bytes.NewReader(req.Payload)
|
address, port, err := s.parseForwardPayload(req.Payload)
|
||||||
|
|
||||||
address, port, err := s.parseForwardPayload(reader)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
|
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
|
||||||
}
|
}
|
||||||
@@ -376,13 +357,13 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
|
|||||||
|
|
||||||
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
|
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
|
||||||
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
|
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
|
||||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||||
}
|
}
|
||||||
|
|
||||||
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
|
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
|
||||||
listener, err := tcpServer.Listen()
|
listener, err := tcpServer.Listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||||
}
|
}
|
||||||
|
|
||||||
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
|
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
|
||||||
@@ -405,20 +386,6 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readSSHString(reader io.Reader) (string, error) {
|
|
||||||
var length uint32
|
|
||||||
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
strBytes := make([]byte, length)
|
|
||||||
if length > 0 {
|
|
||||||
if _, err := io.ReadFull(reader, strBytes); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return string(strBytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isBlockedPort(port uint16) bool {
|
func isBlockedPort(port uint16) bool {
|
||||||
if port == 80 || port == 443 {
|
if port == 80 || port == 443 {
|
||||||
return false
|
return false
|
||||||
|
|||||||
+6
-76
@@ -1,7 +1,6 @@
|
|||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
@@ -287,52 +286,6 @@ func TestIsBlockedPort(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadSSHString(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input []byte
|
|
||||||
want string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid string",
|
|
||||||
input: append([]byte{0, 0, 0, 4}, []byte("test")...),
|
|
||||||
want: "test",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty string",
|
|
||||||
input: []byte{0, 0, 0, 0},
|
|
||||||
want: "",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "short length",
|
|
||||||
input: []byte{0, 0, 0},
|
|
||||||
want: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing payload",
|
|
||||||
input: []byte{0, 0, 0, 4, 'a', 'b'},
|
|
||||||
want: "",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := readSSHString(bytes.NewReader(tt.input))
|
|
||||||
if tt.wantErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.want, got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleGlobalRequest(t *testing.T) {
|
func TestHandleGlobalRequest(t *testing.T) {
|
||||||
_, sReqs, _, cConn, cleanup := setupSSH(t)
|
_, sReqs, _, cConn, cleanup := setupSSH(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
@@ -1033,7 +986,7 @@ func TestParseForwardPayload_Errors(t *testing.T) {
|
|||||||
s := &session{}
|
s := &session{}
|
||||||
|
|
||||||
t.Run("Short Address", func(t *testing.T) {
|
t.Run("Short Address", func(t *testing.T) {
|
||||||
_, _, err := s.parseForwardPayload(bytes.NewReader([]byte{0, 0, 0, 4}))
|
_, _, err := s.parseForwardPayload([]byte{0, 0, 0, 4})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error, got nil")
|
t.Error("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -1041,7 +994,7 @@ func TestParseForwardPayload_Errors(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Short Port", func(t *testing.T) {
|
t.Run("Short Port", func(t *testing.T) {
|
||||||
payload := append([]byte{0, 0, 0, 4}, []byte("addr")...)
|
payload := append([]byte{0, 0, 0, 4}, []byte("addr")...)
|
||||||
_, _, err := s.parseForwardPayload(bytes.NewReader(payload))
|
_, _, err := s.parseForwardPayload(payload)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error, got nil")
|
t.Error("expected error, got nil")
|
||||||
}
|
}
|
||||||
@@ -1052,7 +1005,7 @@ func TestParseForwardPayload_Errors(t *testing.T) {
|
|||||||
portBuf := make([]byte, 4)
|
portBuf := make([]byte, 4)
|
||||||
binary.BigEndian.PutUint32(portBuf, 22)
|
binary.BigEndian.PutUint32(portBuf, 22)
|
||||||
payload = append(payload, portBuf...)
|
payload = append(payload, portBuf...)
|
||||||
_, _, err := s.parseForwardPayload(bytes.NewReader(payload))
|
_, _, err := s.parseForwardPayload(payload)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error, got nil")
|
t.Error("expected error, got nil")
|
||||||
} else if !strings.Contains(err.Error(), "port is block") {
|
} else if !strings.Contains(err.Error(), "port is block") {
|
||||||
@@ -1160,11 +1113,7 @@ func TestDenyForwardingRequest_Full(t *testing.T) {
|
|||||||
req := getReq(t, cConn, sReqs)
|
req := getReq(t, cConn, sReqs)
|
||||||
mCloser := &mockCloser{err: fmt.Errorf("close error")}
|
mCloser := &mockCloser{err: fmt.Errorf("close error")}
|
||||||
err := s.denyForwardingRequest(req, nil, mCloser, "error")
|
err := s.denyForwardingRequest(req, nil, mCloser, "error")
|
||||||
if err == nil {
|
assert.Error(t, err, net.ErrClosed)
|
||||||
t.Error("expected error, got nil")
|
|
||||||
} else if !strings.Contains(err.Error(), "close listener: close error") {
|
|
||||||
t.Errorf("expected error to contain %q, got %q", "close listener: close error", err.Error())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Reply error", func(t *testing.T) {
|
t.Run("Reply error", func(t *testing.T) {
|
||||||
@@ -1174,27 +1123,8 @@ func TestDenyForwardingRequest_Full(t *testing.T) {
|
|||||||
cConn.Close()
|
cConn.Close()
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
err := s.denyForwardingRequest(req, nil, nil, "error")
|
err := s.denyForwardingRequest(req, nil, nil, assert.AnError.Error())
|
||||||
if err == nil {
|
assert.Error(t, err, assert.AnError)
|
||||||
t.Error("expected error, got nil")
|
|
||||||
} else if !strings.Contains(err.Error(), "reply request") {
|
|
||||||
t.Errorf("expected error to contain %q, got %q", "reply request", err.Error())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Lifecycle Close error", func(t *testing.T) {
|
|
||||||
s, _, _, sReqs, cConn, cleanup := setup(t)
|
|
||||||
defer cleanup()
|
|
||||||
req := getReq(t, cConn, sReqs)
|
|
||||||
mLife := &mockLifecycle{closeErr: fmt.Errorf("life close error")}
|
|
||||||
s.lifecycle = mLife
|
|
||||||
|
|
||||||
err := s.denyForwardingRequest(req, nil, nil, "error")
|
|
||||||
if err == nil {
|
|
||||||
t.Error("expected error, got nil")
|
|
||||||
} else if !strings.Contains(err.Error(), "close session: life close error") {
|
|
||||||
t.Errorf("expected error to contain %q, got %q", "close session: life close error", err.Error())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user