refactor(registry): define reusable constant errors
SonarQube Scan / SonarQube Trigger (push) Successful in 1m3s
SonarQube Scan / SonarQube Trigger (push) Successful in 1m3s
- Introduced package-level error variables in registry to replace repeated fmt.Errorf calls - Added errors like ErrSessionNotFound, ErrSlugInUse, ErrInvalidSlug, ErrForbiddenSlug, ErrSlugChangeNotAllowed, and ErrSlugUnchanged
This commit is contained in:
@@ -34,6 +34,15 @@ type registry struct {
|
|||||||
slugIndex map[Key]string
|
slugIndex map[Key]string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrSessionNotFound = fmt.Errorf("session not found")
|
||||||
|
ErrSlugInUse = fmt.Errorf("slug already in use")
|
||||||
|
ErrInvalidSlug = fmt.Errorf("invalid slug")
|
||||||
|
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
|
||||||
|
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
|
||||||
|
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
|
||||||
|
)
|
||||||
|
|
||||||
func NewRegistry() Registry {
|
func NewRegistry() Registry {
|
||||||
return ®istry{
|
return ®istry{
|
||||||
byUser: make(map[string]map[Key]Session),
|
byUser: make(map[string]map[Key]Session),
|
||||||
@@ -47,12 +56,12 @@ func (r *registry) Get(key Key) (session Session, err error) {
|
|||||||
|
|
||||||
userID, ok := r.slugIndex[key]
|
userID, ok := r.slugIndex[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("session not found")
|
return nil, ErrSessionNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
client, ok := r.byUser[userID][key]
|
client, ok := r.byUser[userID][key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("session not found")
|
return nil, ErrSessionNotFound
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
@@ -63,37 +72,37 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error
|
|||||||
|
|
||||||
client, ok := r.byUser[user][key]
|
client, ok := r.byUser[user][key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("session not found")
|
return nil, ErrSessionNotFound
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
||||||
if oldKey.Type != newKey.Type {
|
if oldKey.Type != newKey.Type {
|
||||||
return fmt.Errorf("tunnel type cannot change")
|
return ErrSlugUnchanged
|
||||||
}
|
}
|
||||||
|
|
||||||
if newKey.Type != types.TunnelTypeHTTP {
|
if newKey.Type != types.TunnelTypeHTTP {
|
||||||
return fmt.Errorf("non http tunnel cannot change slug")
|
return ErrSlugChangeNotAllowed
|
||||||
}
|
}
|
||||||
|
|
||||||
if isForbiddenSlug(newKey.Id) {
|
if isForbiddenSlug(newKey.Id) {
|
||||||
return fmt.Errorf("this subdomain is reserved. Please choose a different one")
|
return ErrForbiddenSlug
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isValidSlug(newKey.Id) {
|
if !isValidSlug(newKey.Id) {
|
||||||
return fmt.Errorf("invalid subdomain. Follow the rules")
|
return ErrInvalidSlug
|
||||||
}
|
}
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
||||||
return fmt.Errorf("someone already uses this subdomain")
|
return ErrSlugInUse
|
||||||
}
|
}
|
||||||
client, ok := r.byUser[user][oldKey]
|
client, ok := r.byUser[user][oldKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("session not found")
|
return ErrSessionNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(r.byUser[user], oldKey)
|
delete(r.byUser[user], oldKey)
|
||||||
|
|||||||
@@ -0,0 +1,632 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
"tunnel_pls/internal/port"
|
||||||
|
"tunnel_pls/session/forwarder"
|
||||||
|
"tunnel_pls/session/interaction"
|
||||||
|
"tunnel_pls/session/lifecycle"
|
||||||
|
"tunnel_pls/session/slug"
|
||||||
|
"tunnel_pls/types"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockSession struct{ user string }
|
||||||
|
|
||||||
|
func (m *mockSession) Lifecycle() lifecycle.Lifecycle { return &mockLifecycle{user: m.user} }
|
||||||
|
func (m *mockSession) Interaction() interaction.Interaction {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockSession) Forwarder() forwarder.Forwarder {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockSession) Slug() slug.Slug {
|
||||||
|
return &mockSlug{}
|
||||||
|
}
|
||||||
|
func (m *mockSession) Detail() *types.Detail {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockLifecycle struct{ user string }
|
||||||
|
|
||||||
|
func (ml *mockLifecycle) Connection() ssh.Conn { return nil }
|
||||||
|
func (ml *mockLifecycle) PortRegistry() port.Port { return nil }
|
||||||
|
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { _ = channel }
|
||||||
|
func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { _ = status }
|
||||||
|
func (ml *mockLifecycle) IsActive() bool { return false }
|
||||||
|
func (ml *mockLifecycle) StartedAt() time.Time { return time.Time{} }
|
||||||
|
func (ml *mockLifecycle) Close() error { return nil }
|
||||||
|
func (ml *mockLifecycle) User() string { return ml.user }
|
||||||
|
|
||||||
|
type mockSlug struct{}
|
||||||
|
|
||||||
|
func (ms *mockSlug) Set(slug string) { _ = slug }
|
||||||
|
func (ms *mockSlug) String() string { return "" }
|
||||||
|
|
||||||
|
func TestNewRegistry(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
if r == nil {
|
||||||
|
t.Fatal("NewRegistry returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_Get(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupFunc func(r *registry)
|
||||||
|
key types.SessionKey
|
||||||
|
wantErr error
|
||||||
|
wantResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "session found",
|
||||||
|
setupFunc: func(r *registry) {
|
||||||
|
user := "user1"
|
||||||
|
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: user}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser[user] = map[types.SessionKey]Session{
|
||||||
|
key: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[key] = user
|
||||||
|
},
|
||||||
|
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||||
|
wantErr: nil,
|
||||||
|
wantResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session not found in slugIndex",
|
||||||
|
setupFunc: func(r *registry) {},
|
||||||
|
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
|
||||||
|
wantErr: ErrSessionNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session not found in byUser",
|
||||||
|
setupFunc: func(r *registry) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
|
||||||
|
},
|
||||||
|
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||||
|
wantErr: ErrSessionNotFound,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := ®istry{
|
||||||
|
byUser: make(map[string]map[types.SessionKey]Session),
|
||||||
|
slugIndex: make(map[types.SessionKey]string),
|
||||||
|
mu: sync.RWMutex{},
|
||||||
|
}
|
||||||
|
tt.setupFunc(r)
|
||||||
|
|
||||||
|
session, err := r.Get(tt.key)
|
||||||
|
|
||||||
|
if !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Fatalf("expected error %v, got %v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (session != nil) != tt.wantResult {
|
||||||
|
t.Fatalf("expected session existence to be %v, got %v", tt.wantResult, session != nil)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_GetWithUser(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupFunc func(r *registry)
|
||||||
|
user string
|
||||||
|
key types.SessionKey
|
||||||
|
wantErr error
|
||||||
|
wantResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "session found",
|
||||||
|
setupFunc: func(r *registry) {
|
||||||
|
user := "user1"
|
||||||
|
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: user}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser[user] = map[types.SessionKey]Session{
|
||||||
|
key: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[key] = user
|
||||||
|
},
|
||||||
|
user: "user1",
|
||||||
|
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||||
|
wantErr: nil,
|
||||||
|
wantResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session not found in slugIndex",
|
||||||
|
setupFunc: func(r *registry) {},
|
||||||
|
user: "user1",
|
||||||
|
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
|
||||||
|
wantErr: ErrSessionNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session not found in byUser",
|
||||||
|
setupFunc: func(r *registry) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
|
||||||
|
},
|
||||||
|
user: "user1",
|
||||||
|
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||||
|
wantErr: ErrSessionNotFound,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
r := ®istry{
|
||||||
|
byUser: make(map[string]map[types.SessionKey]Session),
|
||||||
|
slugIndex: make(map[types.SessionKey]string),
|
||||||
|
mu: sync.RWMutex{},
|
||||||
|
}
|
||||||
|
tt.setupFunc(r)
|
||||||
|
|
||||||
|
session, err := r.GetWithUser(tt.user, tt.key)
|
||||||
|
|
||||||
|
if !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Fatalf("expected error %v, got %v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (session != nil) != tt.wantResult {
|
||||||
|
t.Fatalf("expected session existence to be %v, got %v", tt.wantResult, session != nil)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_Update(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
user string
|
||||||
|
setupFunc func(r *registry) (oldKey, newKey types.SessionKey)
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "change slug success",
|
||||||
|
user: "user1",
|
||||||
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
|
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||||
|
oldKey: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[oldKey] = "user1"
|
||||||
|
|
||||||
|
return oldKey, newKey
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "change slug to already used slug",
|
||||||
|
user: "user1",
|
||||||
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
|
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||||
|
oldKey: session,
|
||||||
|
newKey: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[oldKey] = "user1"
|
||||||
|
r.slugIndex[newKey] = "user1"
|
||||||
|
|
||||||
|
return oldKey, newKey
|
||||||
|
},
|
||||||
|
wantErr: ErrSlugInUse,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "change slug to forbidden slug",
|
||||||
|
user: "user1",
|
||||||
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
|
newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||||
|
oldKey: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[oldKey] = "user1"
|
||||||
|
|
||||||
|
return oldKey, newKey
|
||||||
|
},
|
||||||
|
wantErr: ErrForbiddenSlug,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "change slug to invalid slug",
|
||||||
|
user: "user1",
|
||||||
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
|
newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||||
|
oldKey: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[oldKey] = "user1"
|
||||||
|
|
||||||
|
return oldKey, newKey
|
||||||
|
},
|
||||||
|
wantErr: ErrInvalidSlug,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "change slug but session not found",
|
||||||
|
user: "user2",
|
||||||
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
|
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||||
|
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||||
|
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
|
||||||
|
|
||||||
|
return oldKey, newKey
|
||||||
|
},
|
||||||
|
wantErr: ErrSessionNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "change slug but session is not in the map",
|
||||||
|
user: "user2",
|
||||||
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
|
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||||
|
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||||
|
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
|
||||||
|
|
||||||
|
return oldKey, newKey
|
||||||
|
},
|
||||||
|
wantErr: ErrSessionNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "change slug with same slug",
|
||||||
|
user: "user1",
|
||||||
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
|
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||||
|
oldKey: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[oldKey] = "user1"
|
||||||
|
|
||||||
|
return oldKey, newKey
|
||||||
|
},
|
||||||
|
wantErr: ErrSlugUnchanged,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp tunnel cannot change slug",
|
||||||
|
user: "user1",
|
||||||
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
|
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
||||||
|
newKey := oldKey
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||||
|
oldKey: session,
|
||||||
|
}
|
||||||
|
r.slugIndex[oldKey] = "user1"
|
||||||
|
|
||||||
|
return oldKey, newKey
|
||||||
|
},
|
||||||
|
wantErr: ErrSlugChangeNotAllowed,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
r := ®istry{
|
||||||
|
byUser: make(map[string]map[types.SessionKey]Session),
|
||||||
|
slugIndex: make(map[types.SessionKey]string),
|
||||||
|
mu: sync.RWMutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
oldKey, newKey := tt.setupFunc(r)
|
||||||
|
|
||||||
|
err := r.Update(tt.user, oldKey, newKey)
|
||||||
|
if !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Fatalf("expected error %v, got %v", tt.wantErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
if _, ok := r.byUser[tt.user][newKey]; !ok {
|
||||||
|
t.Errorf("newKey not found in registry")
|
||||||
|
}
|
||||||
|
if _, ok := r.byUser[tt.user][oldKey]; ok {
|
||||||
|
t.Errorf("oldKey still exists in registry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_Register(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
user string
|
||||||
|
setupFunc func(r *registry) Key
|
||||||
|
wantOK bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "register new key successfully",
|
||||||
|
user: "user1",
|
||||||
|
setupFunc: func(r *registry) Key {
|
||||||
|
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
|
return key
|
||||||
|
},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "register already existing key fails",
|
||||||
|
user: "user1",
|
||||||
|
setupFunc: func(r *registry) Key {
|
||||||
|
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
r.byUser["user1"] = map[Key]Session{key: session}
|
||||||
|
r.slugIndex[key] = "user1"
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
return key
|
||||||
|
},
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "register multiple keys for same user",
|
||||||
|
user: "user1",
|
||||||
|
setupFunc: func(r *registry) Key {
|
||||||
|
firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: "user1"}
|
||||||
|
r.mu.Lock()
|
||||||
|
r.byUser["user1"] = map[Key]Session{firstKey: session}
|
||||||
|
r.slugIndex[firstKey] = "user1"
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
return types.SessionKey{Id: "second", Type: types.TunnelTypeHTTP}
|
||||||
|
},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
r := ®istry{
|
||||||
|
byUser: make(map[string]map[Key]Session),
|
||||||
|
slugIndex: make(map[Key]string),
|
||||||
|
mu: sync.RWMutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
key := tt.setupFunc(r)
|
||||||
|
session := &mockSession{user: tt.user}
|
||||||
|
|
||||||
|
ok := r.Register(key, session)
|
||||||
|
if ok != tt.wantOK {
|
||||||
|
t.Fatalf("expected success %v, got %v", tt.wantOK, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
if r.byUser[tt.user][key] != session {
|
||||||
|
t.Errorf("session not stored in byUser")
|
||||||
|
}
|
||||||
|
if r.slugIndex[key] != tt.user {
|
||||||
|
t.Errorf("slugIndex not updated")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_GetAllSessionFromUser(t *testing.T) {
|
||||||
|
t.Run("user has no sessions", func(t *testing.T) {
|
||||||
|
r := ®istry{
|
||||||
|
byUser: make(map[string]map[Key]Session),
|
||||||
|
slugIndex: make(map[Key]string),
|
||||||
|
}
|
||||||
|
sessions := r.GetAllSessionFromUser("user1")
|
||||||
|
if len(sessions) != 0 {
|
||||||
|
t.Errorf("expected 0 sessions, got %d", len(sessions))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("user has multiple sessions", func(t *testing.T) {
|
||||||
|
r := ®istry{
|
||||||
|
byUser: make(map[string]map[Key]Session),
|
||||||
|
slugIndex: make(map[Key]string),
|
||||||
|
mu: sync.RWMutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
user := "user1"
|
||||||
|
key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||||
|
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
|
||||||
|
session1 := &mockSession{user: user}
|
||||||
|
session2 := &mockSession{user: user}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
r.byUser[user] = map[Key]Session{
|
||||||
|
key1: session1,
|
||||||
|
key2: session2,
|
||||||
|
}
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
sessions := r.GetAllSessionFromUser(user)
|
||||||
|
if len(sessions) != 2 {
|
||||||
|
t.Errorf("expected 2 sessions, got %d", len(sessions))
|
||||||
|
}
|
||||||
|
|
||||||
|
found := map[Session]bool{}
|
||||||
|
for _, s := range sessions {
|
||||||
|
found[s] = true
|
||||||
|
}
|
||||||
|
if !found[session1] || !found[session2] {
|
||||||
|
t.Errorf("returned sessions do not match expected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_Remove(t *testing.T) {
|
||||||
|
t.Run("remove existing key", func(t *testing.T) {
|
||||||
|
r := ®istry{
|
||||||
|
byUser: make(map[string]map[Key]Session),
|
||||||
|
slugIndex: make(map[Key]string),
|
||||||
|
mu: sync.RWMutex{},
|
||||||
|
}
|
||||||
|
|
||||||
|
user := "user1"
|
||||||
|
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||||
|
session := &mockSession{user: user}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
r.byUser[user] = map[Key]Session{key: session}
|
||||||
|
r.slugIndex[key] = user
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
r.Remove(key)
|
||||||
|
|
||||||
|
if _, ok := r.byUser[user][key]; ok {
|
||||||
|
t.Errorf("expected key to be removed from byUser")
|
||||||
|
}
|
||||||
|
if _, ok := r.slugIndex[key]; ok {
|
||||||
|
t.Errorf("expected key to be removed from slugIndex")
|
||||||
|
}
|
||||||
|
if _, ok := r.byUser[user]; ok {
|
||||||
|
t.Errorf("expected user to be removed from byUser map")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("remove non-existing key", func(t *testing.T) {
|
||||||
|
r := ®istry{
|
||||||
|
byUser: make(map[string]map[Key]Session),
|
||||||
|
slugIndex: make(map[Key]string),
|
||||||
|
}
|
||||||
|
r.Remove(types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidSlug(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
slug string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"abc", true},
|
||||||
|
{"abc-123", true},
|
||||||
|
{"a", false},
|
||||||
|
{"verybigdihsixsevenlabubu", false},
|
||||||
|
{"-iamsigma", false},
|
||||||
|
{"ligma-", false},
|
||||||
|
{"invalid$", false},
|
||||||
|
{"valid-slug1", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.slug, func(t *testing.T) {
|
||||||
|
got := isValidSlug(tt.slug)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isValidSlug(%q) = %v; want %v", tt.slug, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidSlugChar(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
char byte
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{'a', true},
|
||||||
|
{'z', true},
|
||||||
|
{'0', true},
|
||||||
|
{'9', true},
|
||||||
|
{'-', true},
|
||||||
|
{'A', false},
|
||||||
|
{'$', false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(string(tt.char), func(t *testing.T) {
|
||||||
|
got := isValidSlugChar(tt.char)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isValidSlugChar(%q) = %v; want %v", tt.char, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsForbiddenSlug(t *testing.T) {
|
||||||
|
forbiddenSlugs = map[string]struct{}{
|
||||||
|
"admin": {},
|
||||||
|
"root": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
slug string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"admin", true},
|
||||||
|
{"root", true},
|
||||||
|
{"user", false},
|
||||||
|
{"guest", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.slug, func(t *testing.T) {
|
||||||
|
got := isForbiddenSlug(tt.slug)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("isForbiddenSlug(%q) = %v; want %v", tt.slug, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -34,7 +34,7 @@ func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTL
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
|
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
|
||||||
_, err := conn.Write([]byte(fmt.Sprintf("TunnelTypeHTTP/1.1 %d Moved Permanently\r\n", status) +
|
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
|
||||||
fmt.Sprintf("Location: %s", location) +
|
fmt.Sprintf("Location: %s", location) +
|
||||||
"Content-Length: 0\r\n" +
|
"Content-Length: 0\r\n" +
|
||||||
"Connection: close\r\n" +
|
"Connection: close\r\n" +
|
||||||
@@ -46,7 +46,7 @@ func (hh *httpHandler) redirect(conn net.Conn, status int, location string) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hh *httpHandler) badRequest(conn net.Conn) error {
|
func (hh *httpHandler) badRequest(conn net.Conn) error {
|
||||||
if _, err := conn.Write([]byte("TunnelTypeHTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
|
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -87,7 +87,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
|||||||
defer func(hw stream.HTTP) {
|
defer func(hw stream.HTTP) {
|
||||||
err = hw.Close()
|
err = hw.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error closing TunnelTypeHTTP stream: %v", err)
|
log.Printf("Error closing HTTP stream: %v", err)
|
||||||
}
|
}
|
||||||
}(hw)
|
}(hw)
|
||||||
hh.forwardRequest(hw, reqhf, sshSession)
|
hh.forwardRequest(hw, reqhf, sshSession)
|
||||||
@@ -118,7 +118,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := conn.Write([]byte(
|
_, err := conn.Write([]byte(
|
||||||
"TunnelTypeHTTP/1.1 200 OK\r\n" +
|
"HTTP/1.1 200 OK\r\n" +
|
||||||
"Content-Length: 0\r\n" +
|
"Content-Length: 0\r\n" +
|
||||||
"Connection: close\r\n" +
|
"Connection: close\r\n" +
|
||||||
"Access-Control-Allow-Origin: *\r\n" +
|
"Access-Control-Allow-Origin: *\r\n" +
|
||||||
|
|||||||
@@ -145,9 +145,9 @@ func (m *model) slugView() string {
|
|||||||
|
|
||||||
var warningText string
|
var warningText string
|
||||||
if isVeryCompact {
|
if isVeryCompact {
|
||||||
warningText = "⚠️ TunnelTypeTCP tunnels don't support custom subdomains."
|
warningText = "⚠️ TCP tunnels don't support custom subdomains."
|
||||||
} else {
|
} else {
|
||||||
warningText = "⚠️ TunnelTypeTCP tunnels cannot have custom subdomains. Only TunnelTypeHTTP/HTTPS tunnels support subdomain customization."
|
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
|
||||||
}
|
}
|
||||||
b.WriteString(warningBoxStyle.Render(warningText))
|
b.WriteString(warningBoxStyle.Render(warningText))
|
||||||
b.WriteString("\n\n")
|
b.WriteString("\n\n")
|
||||||
|
|||||||
+1
-1
@@ -160,7 +160,7 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handleMissingForwardRequest() error {
|
func (s *session) handleMissingForwardRequest() error {
|
||||||
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain, s.config.SSHPort))
|
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
+1
-1
@@ -44,7 +44,7 @@ type Detail struct {
|
|||||||
StartedAt time.Time `json:"started_at,omitempty"`
|
StartedAt time.Time `json:"started_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var BadGatewayResponse = []byte("TunnelTypeHTTP/1.1 502 Bad Gateway\r\n" +
|
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
|
||||||
"Content-Length: 11\r\n" +
|
"Content-Length: 11\r\n" +
|
||||||
"Content-Type: text/plain\r\n\r\n" +
|
"Content-Type: text/plain\r\n\r\n" +
|
||||||
"Bad Gateway")
|
"Bad Gateway")
|
||||||
|
|||||||
Reference in New Issue
Block a user