refactor: remove custom parsing functions and use ssh.Marshal/ssh.Unmarshal for serialization
This commit is contained in:
+6
-76
@@ -1,7 +1,6 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
@@ -287,52 +286,6 @@ func TestIsBlockedPort(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadSSHString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid string",
|
||||
input: append([]byte{0, 0, 0, 4}, []byte("test")...),
|
||||
want: "test",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: []byte{0, 0, 0, 0},
|
||||
want: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "short length",
|
||||
input: []byte{0, 0, 0},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing payload",
|
||||
input: []byte{0, 0, 0, 4, 'a', 'b'},
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := readSSHString(bytes.NewReader(tt.input))
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGlobalRequest(t *testing.T) {
|
||||
_, sReqs, _, cConn, cleanup := setupSSH(t)
|
||||
defer cleanup()
|
||||
@@ -1033,7 +986,7 @@ func TestParseForwardPayload_Errors(t *testing.T) {
|
||||
s := &session{}
|
||||
|
||||
t.Run("Short Address", func(t *testing.T) {
|
||||
_, _, err := s.parseForwardPayload(bytes.NewReader([]byte{0, 0, 0, 4}))
|
||||
_, _, err := s.parseForwardPayload([]byte{0, 0, 0, 4})
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
@@ -1041,7 +994,7 @@ func TestParseForwardPayload_Errors(t *testing.T) {
|
||||
|
||||
t.Run("Short Port", func(t *testing.T) {
|
||||
payload := append([]byte{0, 0, 0, 4}, []byte("addr")...)
|
||||
_, _, err := s.parseForwardPayload(bytes.NewReader(payload))
|
||||
_, _, err := s.parseForwardPayload(payload)
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
@@ -1052,7 +1005,7 @@ func TestParseForwardPayload_Errors(t *testing.T) {
|
||||
portBuf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(portBuf, 22)
|
||||
payload = append(payload, portBuf...)
|
||||
_, _, err := s.parseForwardPayload(bytes.NewReader(payload))
|
||||
_, _, err := s.parseForwardPayload(payload)
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), "port is block") {
|
||||
@@ -1160,11 +1113,7 @@ func TestDenyForwardingRequest_Full(t *testing.T) {
|
||||
req := getReq(t, cConn, sReqs)
|
||||
mCloser := &mockCloser{err: fmt.Errorf("close error")}
|
||||
err := s.denyForwardingRequest(req, nil, mCloser, "error")
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), "close listener: close error") {
|
||||
t.Errorf("expected error to contain %q, got %q", "close listener: close error", err.Error())
|
||||
}
|
||||
assert.Error(t, err, net.ErrClosed)
|
||||
})
|
||||
|
||||
t.Run("Reply error", func(t *testing.T) {
|
||||
@@ -1174,27 +1123,8 @@ func TestDenyForwardingRequest_Full(t *testing.T) {
|
||||
cConn.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err := s.denyForwardingRequest(req, nil, nil, "error")
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), "reply request") {
|
||||
t.Errorf("expected error to contain %q, got %q", "reply request", err.Error())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Lifecycle Close error", func(t *testing.T) {
|
||||
s, _, _, sReqs, cConn, cleanup := setup(t)
|
||||
defer cleanup()
|
||||
req := getReq(t, cConn, sReqs)
|
||||
mLife := &mockLifecycle{closeErr: fmt.Errorf("life close error")}
|
||||
s.lifecycle = mLife
|
||||
|
||||
err := s.denyForwardingRequest(req, nil, nil, "error")
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), "close session: life close error") {
|
||||
t.Errorf("expected error to contain %q, got %q", "close session: life close error", err.Error())
|
||||
}
|
||||
err := s.denyForwardingRequest(req, nil, nil, assert.AnError.Error())
|
||||
assert.Error(t, err, assert.AnError)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user