25 Commits

Author SHA1 Message Date
aafea49975 feat: integrate gRPC, session refactor, SSH headless support, and bug fixes
Docker Build and Push / build-and-push-tags (push) Successful in 11m34s
Docker Build and Push / build-and-push-branches (push) Has been skipped
- gRPC integration: slug edit handling, get sessions by user, and session requests from gRPC server
- Refactor gRPC client: simplify processEventStream and handle authenticated user info
- Session management improvements: use session key for registry, forwarder session termination, inject SessionRegistry interface
- SSH enhancements: add headless mode support for SSH -N connections
- Bug fixes:
  - prevent subdomain changes to already-in-use subdomains
  - fix startup order and environment variable keys
  - atomic ClaimPort() to prevent race conditions
- Refactors:
  - consolidate error handling
  - replace Get/Set patterns with idiomatic Go interfaces
  - change enums from string to int
- CI cleanup: remove renovate bot

Reviewed-on: #65
2026-01-14 10:16:43 +00:00
dbdf8094fa refactor: replace Get/Set patterns with idiomatic Go interfaces
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 13m4s
- rename constructors to New
- remove Get/Set-style accessors
- replace string-based enums with iota-backed types
2026-01-14 16:54:10 +07:00
ae3ed52d16 fix(port): add atomic ClaimPort() to prevent race condition
- Replace GetPortStatus/SetPortStatus calls with atomic ClaimPort() operation.
- Fixed a logic error when handling headless tunneling.
2026-01-14 16:51:50 +07:00
fb638636bf refactor: consolidate error handling with fail() function in session handlers
- Replace repetitive error handling code with fail() function in HandleGlobalRequest
- Standardize error response pattern across all handler methods
- Improve code maintainability and reduce duplication
2026-01-14 16:51:50 +07:00
da29df85b7 feat: add headless mode support for SSH -N connections
- use s.lifecycle.GetConnection().Wait() to block until SSH connection closes
- Prevent premature session closure in headless mode

In headless mode (ssh -N), there's no channel interaction to block on,
so the session would immediately return and close. Now blocking on
conn.Wait() keeps the session alive until the client disconnects.
2026-01-14 16:51:50 +07:00
8b0e08c629 fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 2026-01-14 16:51:50 +07:00
f0804d6946 ci: remove renovate 2026-01-14 16:51:50 +07:00
09e526cd1e feat: add authenticated user info and restructure handleConnection
- Display authenticated username in welcome page information box
- Refactor handleConnection function for better structure and clarity
2026-01-14 16:51:50 +07:00
887ebf78b1 refactor(grpc/client): simplify processEventStream with per-event handlers
- Extract eventHandlers dispatch table
- Add per-event handlers: handleSlugChange, handleGetSessions, handleTerminateSession
- Introduce sendNode helper to centralize send/error handling and preserve connection-error propagation
- Add protoToTunnelType for tunnel-type validation
- Map unknown proto.TunnelType to types.UNKNOWN in protoToTunnelType and return a descriptive error
- Reduce boilerplate and improve readability of processEventStream
2026-01-14 16:51:50 +07:00
bef7a49f88 feat: implement forwarder session termination 2026-01-14 16:51:50 +07:00
17633b4e3c refactor: inject SessionRegistry interface instead of individual functions 2026-01-14 16:51:50 +07:00
f25d61d1d1 update: proto file to v1.3.0 2026-01-14 16:51:50 +07:00
8782b77b74 feat(session): use session key for registry 2026-01-14 16:51:50 +07:00
fc3cd886db fix: use correct environment variable key 2026-01-14 16:51:50 +07:00
b0da57db0d fix: startup order 2026-01-14 16:51:50 +07:00
0bd6eeadf3 feat: implement sessions request from grpc server 2026-01-14 16:51:50 +07:00
449f546e04 feat: implement sessions request from grpc server 2026-01-14 16:51:50 +07:00
4644420eee feat: implement get sessions by user 2026-01-14 16:51:50 +07:00
c9bf9e62bd feat(grpc): integrate slug edit handling 2026-01-14 16:51:50 +07:00
57d2136377 WIP: gRPC integration, initial implementation 2026-01-14 16:51:47 +07:00
8a34aaba80 WIP: gRPC integration, initial implementation 2026-01-14 16:51:35 +07:00
ff995a929e revert 01ddc76f7e
revert Merge pull request 'fix(deps): update module github.com/caddyserver/certmagic to v0.25.1' (#58) from renovate/github.com-caddyserver-certmagic-0.x into main
2026-01-14 16:51:35 +07:00
32ac9c1749 fix(deps): update module github.com/caddyserver/certmagic to v0.25.1
# Conflicts:
#	go.mod
2026-01-14 16:51:30 +07:00
e051a5b742 Merge pull request 'fix(deps): update module golang.org/x/crypto to v0.47.0' (#64) from renovate/golang.org-x-crypto-0.x into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 9m51s
renovate / renovate (push) Successful in 55s
2026-01-12 18:20:57 +00:00
d35228759c fix(deps): update module golang.org/x/crypto to v0.47.0 2026-01-12 18:20:53 +00:00
18 changed files with 742 additions and 596 deletions
-21
View File
@@ -1,21 +0,0 @@
name: renovate
on:
schedule:
- cron: "0 0 * * *"
push:
branches:
- staging
jobs:
renovate:
runs-on: ubuntu-latest
container: git.fossy.my.id/renovate-clanker/renovate:latest
steps:
- uses: actions/checkout@v6
- run: renovate
env:
RENOVATE_CONFIG_FILE: ${{ gitea.workspace }}/renovate-config.js
LOG_LEVEL: "debug"
RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }}
GITHUB_COM_TOKEN: ${{ secrets.COM_TOKEN }}
+3 -3
View File
@@ -3,8 +3,8 @@ module tunnel_pls
go 1.25.5 go 1.25.5
require ( require (
git.fossy.my.id/bagas/tunnel-please-grpc v1.2.0 git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0
github.com/caddyserver/certmagic v0.25.0 github.com/caddyserver/certmagic v0.25.1
github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbles v0.21.0
github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/lipgloss v1.1.0
@@ -19,7 +19,7 @@ require (
require ( require (
github.com/atotto/clipboard v0.1.4 // indirect github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect github.com/caddyserver/zerossl v0.1.4 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.3 // indirect github.com/charmbracelet/x/ansi v0.11.3 // indirect
github.com/charmbracelet/x/cellbuf v0.0.14 // indirect github.com/charmbracelet/x/cellbuf v0.0.14 // indirect
+10 -2
View File
@@ -1,5 +1,9 @@
git.fossy.my.id/bagas/tunnel-please-grpc v1.2.0 h1:BS1dJU3wa2ILgTGwkV95Knle0il0OQtErGqyb6xV7SU= git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0 h1:RhcBKUG41/om4jgN+iF/vlY/RojTeX1QhBa4p4428ec=
git.fossy.my.id/bagas/tunnel-please-grpc v1.2.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY= git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY=
git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0 h1:tpJSKjaSmV+vxxbVx6qnStjxFVXjj2M0rygWXxLb99o=
git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY=
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA=
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
@@ -8,8 +12,12 @@ github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWp
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic= github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic=
github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA= github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA=
github.com/caddyserver/certmagic v0.25.1 h1:4sIKKbOt5pg6+sL7tEwymE1x2bj6CHr80da1CRRIPbY=
github.com/caddyserver/certmagic v0.25.1/go.mod h1:VhyvndxtVton/Fo/wKhRoC46Rbw1fmjvQ3GjHYSQTEY=
github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA= github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA=
github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/caddyserver/zerossl v0.1.4 h1:CVJOE3MZeFisCERZjkxIcsqIH4fnFdlYWnPYeFtBHRw=
github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
+124 -80
View File
@@ -9,6 +9,7 @@ import (
"log" "log"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/types"
"tunnel_pls/session" "tunnel_pls/session"
@@ -39,9 +40,9 @@ type Client struct {
conn *grpc.ClientConn conn *grpc.ClientConn
config *GrpcConfig config *GrpcConfig
sessionRegistry session.Registry sessionRegistry session.Registry
slugService proto.SlugChangeClient
eventService proto.EventServiceClient eventService proto.EventServiceClient
authorizeConnectionService proto.UserServiceClient authorizeConnectionService proto.UserServiceClient
closing bool
} }
func DefaultConfig() *GrpcConfig { func DefaultConfig() *GrpcConfig {
@@ -113,14 +114,12 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error)
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", config.Address, err) return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", config.Address, err)
} }
slugService := proto.NewSlugChangeClient(conn)
eventService := proto.NewEventServiceClient(conn) eventService := proto.NewEventServiceClient(conn)
authorizeConnectionService := proto.NewUserServiceClient(conn) authorizeConnectionService := proto.NewUserServiceClient(conn)
return &Client{ return &Client{
conn: conn, conn: conn,
config: config, config: config,
slugService: slugService,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
eventService: eventService, eventService: eventService,
authorizeConnectionService: authorizeConnectionService, authorizeConnectionService: authorizeConnectionService,
@@ -155,16 +154,17 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string
for { for {
subscribe, err := c.eventService.Subscribe(ctx) subscribe, err := c.eventService.Subscribe(ctx)
if err != nil { if err != nil {
if !isConnectionError(err) { if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
return err return err
} }
if status.Code(err) == codes.Unauthenticated { if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
return err return err
} }
if err := wait(); err != nil { if err = wait(); err != nil {
return err return err
} }
growBackoff() growBackoff()
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
continue continue
} }
@@ -180,8 +180,8 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string
if err != nil { if err != nil {
log.Println("Authentication failed to send to gRPC server:", err) log.Println("Authentication failed to send to gRPC server:", err)
if isConnectionError(err) { if c.isConnectionError(err) {
if err := wait(); err != nil { if err = wait(); err != nil {
return err return err
} }
growBackoff() growBackoff()
@@ -193,9 +193,12 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string
backoff = baseBackoff backoff = baseBackoff
if err = c.processEventStream(subscribe); err != nil { if err = c.processEventStream(subscribe); err != nil {
if isConnectionError(err) { if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
return err
}
if c.isConnectionError(err) {
log.Printf("Reconnect to controller within %v sec", backoff.Seconds()) log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
if err := wait(); err != nil { if err = wait(); err != nil {
return err return err
} }
growBackoff() growBackoff()
@@ -207,89 +210,76 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string
} }
func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error { func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error {
handlers := c.eventHandlers(subscribe)
for { for {
recv, err := subscribe.Recv() recv, err := subscribe.Recv()
if err != nil { if err != nil {
if isConnectionError(err) {
log.Printf("connection error receiving from gRPC server: %v", err)
return err return err
} }
if status.Code(err) == codes.Unauthenticated {
log.Printf("Authentication failed: %v", err) handler, ok := handlers[recv.GetType()]
if !ok {
log.Printf("Unknown event type received: %v", recv.GetType())
continue
}
if err = handler(recv); err != nil {
return err return err
} }
log.Printf("non-connection receive error from gRPC server: %v", err)
continue
} }
switch recv.GetType() { }
case proto.EventType_SLUG_CHANGE:
oldSlug := recv.GetSlugEvent().GetOld() func (c *Client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error {
newSlug := recv.GetSlugEvent().GetNew() return map[proto.EventType]func(*proto.Events) error{
sess, err := c.sessionRegistry.Get(oldSlug) proto.EventType_SLUG_CHANGE: func(evt *proto.Events) error { return c.handleSlugChange(subscribe, evt) },
proto.EventType_GET_SESSIONS: func(evt *proto.Events) error { return c.handleGetSessions(subscribe, evt) },
proto.EventType_TERMINATE_SESSION: func(evt *proto.Events) error { return c.handleTerminateSession(subscribe, evt) },
}
}
func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
slugEvent := evt.GetSlugEvent()
user := slugEvent.GetUser()
oldSlug := slugEvent.GetOld()
newSlug := slugEvent.GetNew()
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP})
if err != nil { if err != nil {
errSend := subscribe.Send(&proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{ Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{ SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
Success: false,
Message: err.Error(),
}, },
}, }, "slug change failure response")
})
if errSend != nil {
if isConnectionError(errSend) {
log.Printf("connection error sending slug change failure: %v", errSend)
return errSend
} }
log.Printf("non-connection send error for slug change failure: %v", errSend)
} if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil {
continue return c.sendNode(subscribe, &proto.Node{
}
err = c.sessionRegistry.Update(oldSlug, newSlug)
if err != nil {
errSend := subscribe.Send(&proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{ Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{ SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
Success: false,
Message: err.Error(),
}, },
}, }, "slug change failure response")
})
if errSend != nil {
if isConnectionError(errSend) {
log.Printf("connection error sending slug change failure: %v", errSend)
return errSend
} }
log.Printf("non-connection send error for slug change failure: %v", errSend)
} userSession.Interaction().Redraw()
continue return c.sendNode(subscribe, &proto.Node{
}
sess.GetInteraction().Redraw()
err = subscribe.Send(&proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{ Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{ SlugEventResponse: &proto.SlugChangeEventResponse{Success: true, Message: ""},
Success: true,
Message: "",
}, },
}, }, "slug change success response")
})
if err != nil {
if isConnectionError(err) {
log.Printf("connection error sending slug change success: %v", err)
return err
} }
log.Printf("non-connection send error for slug change success: %v", err)
continue func (c *Client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
} sessions := c.sessionRegistry.GetAllSessionFromUser(evt.GetGetSessionsEvent().GetIdentity())
case proto.EventType_GET_SESSIONS:
sessions := c.sessionRegistry.GetAllSessionFromUser(recv.GetGetSessionsEvent().GetIdentity())
var details []*proto.Detail var details []*proto.Detail
for _, ses := range sessions { for _, ses := range sessions {
detail := ses.Detail() detail := ses.Detail()
details = append(details, &proto.Detail{ details = append(details, &proto.Detail{
Node: config.Getenv("domain", "localhost"), Node: config.Getenv("DOMAIN", "localhost"),
ForwardingType: detail.ForwardingType, ForwardingType: detail.ForwardingType,
Slug: detail.Slug, Slug: detail.Slug,
UserId: detail.UserID, UserId: detail.UserID,
@@ -297,25 +287,75 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
StartedAt: timestamppb.New(detail.StartedAt), StartedAt: timestamppb.New(detail.StartedAt),
}) })
} }
err = subscribe.Send(&proto.Node{
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_GET_SESSIONS, Type: proto.EventType_GET_SESSIONS,
Payload: &proto.Node_GetSessionsEvent{ Payload: &proto.Node_GetSessionsEvent{
GetSessionsEvent: &proto.GetSessionsResponse{ GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
Details: details,
}, },
}, }, "send get sessions response")
}) }
func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
terminate := evt.GetTerminateSessionEvent()
user := terminate.GetUser()
slug := terminate.GetSlug()
tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType())
if err != nil { if err != nil {
if isConnectionError(err) { return c.sendNode(subscribe, &proto.Node{
log.Printf("connection error sending sessions success: %v", err) Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
},
}, "terminate session invalid tunnel type")
}
userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType})
if err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
},
}, "terminate session fetch failed")
}
if err = userSession.Lifecycle().Close(); err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
},
}, "terminate session close failed")
}
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: true, Message: ""},
},
}, "terminate session success response")
}
func (c *Client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error {
if err := subscribe.Send(node); err != nil {
if c.isConnectionError(err) {
return err return err
} }
log.Printf("non-connection send error for sessions success: %v", err) log.Printf("%s: %v", context, err)
continue
} }
return nil
}
func (c *Client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
switch t {
case proto.TunnelType_HTTP:
return types.HTTP, nil
case proto.TunnelType_TCP:
return types.TCP, nil
default: default:
log.Printf("Unknown event type received: %v", recv.GetType()) return types.UNKNOWN, fmt.Errorf("unknown tunnel type received")
}
} }
} }
@@ -338,6 +378,7 @@ func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bo
func (c *Client) Close() error { func (c *Client) Close() error {
if c.conn != nil { if c.conn != nil {
log.Printf("Closing gRPC connection to %s", c.config.Address) log.Printf("Closing gRPC connection to %s", c.config.Address)
c.closing = true
return c.conn.Close() return c.conn.Close()
} }
return nil return nil
@@ -361,7 +402,10 @@ func (c *Client) GetConfig() *GrpcConfig {
return c.config return c.config
} }
func isConnectionError(err error) bool { func (c *Client) isConnectionError(err error) bool {
if c.closing {
return false
}
if err == nil { if err == nil {
return false return false
} }
+16 -6
View File
@@ -13,7 +13,7 @@ type Manager interface {
AddPortRange(startPort, endPort uint16) error AddPortRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool) GetUnassignedPort() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error SetPortStatus(port uint16, assigned bool) error
GetPortStatus(port uint16) (bool, bool) ClaimPort(port uint16) (claimed bool)
} }
type manager struct { type manager struct {
@@ -74,7 +74,6 @@ func (pm *manager) GetUnassignedPort() (uint16, bool) {
for _, port := range pm.sortedPorts { for _, port := range pm.sortedPorts {
if !pm.ports[port] { if !pm.ports[port] {
pm.ports[port] = true
return port, true return port, true
} }
} }
@@ -89,10 +88,21 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
return nil return nil
} }
func (pm *manager) GetPortStatus(port uint16) (bool, bool) { func (pm *manager) ClaimPort(port uint16) (claimed bool) {
pm.mu.RLock() pm.mu.Lock()
defer pm.mu.RUnlock() defer pm.mu.Unlock()
status, exists := pm.ports[port] status, exists := pm.ports[port]
return status, exists
if exists && status {
return false
}
if !exists {
pm.ports[port] = true
return true
}
pm.ports[port] = true
return true
} }
+38 -24
View File
@@ -7,7 +7,9 @@ import (
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"os/signal"
"strings" "strings"
"syscall"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
@@ -68,10 +70,14 @@ func main() {
sshConfig.AddHostKey(private) sshConfig.AddHostKey(private)
sessionRegistry := session.NewRegistry() sessionRegistry := session.NewRegistry()
var grpcClient *client.Client ctx, cancel := context.WithCancel(context.Background())
var cancel context.CancelFunc = func() {} defer cancel()
var ctx context.Context = context.Background()
errChan := make(chan error, 2)
shutdownChan := make(chan os.Signal, 1)
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
var grpcClient *client.Client
if isNodeMode { if isNodeMode {
grpcHost := config.Getenv("GRPC_ADDRESS", "localhost") grpcHost := config.Getenv("GRPC_ADDRESS", "localhost")
grpcPort := config.Getenv("GRPC_PORT", "8080") grpcPort := config.Getenv("GRPC_PORT", "8080")
@@ -79,10 +85,9 @@ func main() {
nodeToken := config.Getenv("NODE_TOKEN", "") nodeToken := config.Getenv("NODE_TOKEN", "")
if nodeToken == "" { if nodeToken == "" {
log.Fatalf("NODE_TOKEN is required in node mode") log.Fatalf("NODE_TOKEN is required in node mode")
return
} }
grpcClient, err = client.New(&client.GrpcConfig{ c, err := client.New(&client.GrpcConfig{
Address: grpcAddr, Address: grpcAddr,
UseTLS: false, UseTLS: false,
InsecureSkipVerify: false, InsecureSkipVerify: false,
@@ -91,37 +96,46 @@ func main() {
MaxRetries: 3, MaxRetries: 3,
}, sessionRegistry) }, sessionRegistry)
if err != nil { if err != nil {
return log.Fatalf("failed to create grpc client: %v", err)
} }
defer func(grpcClient *client.Client) { grpcClient = c
err := grpcClient.Close()
if err != nil {
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
if err := grpcClient.CheckServerHealth(healthCtx); err != nil {
healthCancel()
log.Fatalf("gRPC health check failed: %v", err)
} }
}(grpcClient) healthCancel()
ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
err = grpcClient.CheckServerHealth(ctx)
if err != nil {
log.Fatalf("gRPC health check failed: %s", err)
return
}
cancel()
ctx, cancel = context.WithCancel(context.Background())
go func() { go func() {
identity := config.Getenv("DOMAIN", "localhost") identity := config.Getenv("DOMAIN", "localhost")
err = grpcClient.SubscribeEvents(ctx, identity, nodeToken) if err := grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil {
if err != nil { errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
return
} }
}() }()
} }
go func() {
app, err := server.NewServer(sshConfig, sessionRegistry, grpcClient) app, err := server.NewServer(sshConfig, sessionRegistry, grpcClient)
if err != nil { if err != nil {
log.Fatalf("Failed to start server: %s", err) errChan <- fmt.Errorf("failed to start server: %s", err)
return
} }
app.Start() app.Start()
cancel() }()
select {
case err := <-errChan:
log.Printf("error happen : %s", err)
case sig := <-shutdownChan:
log.Printf("received signal %s, shutting down", sig)
}
cancel()
if grpcClient != nil {
if err := grpcClient.Close(); err != nil {
log.Printf("failed to close grpc conn : %s", err)
}
}
} }
-8
View File
@@ -1,8 +0,0 @@
module.exports = {
"endpoint": "https://git.fossy.my.id/api/v1",
"gitAuthor": "Renovate-Clanker <renovate-bot@fossy.my.id>",
"platform": "gitea",
"onboardingConfigFileName": "renovate.json",
"autodiscover": true,
"optimizeForDisabled": true,
};
+11 -7
View File
@@ -13,6 +13,7 @@ import (
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/session" "tunnel_pls/session"
"tunnel_pls/types"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -313,7 +314,10 @@ func (hs *httpServer) handler(conn net.Conn) {
return return
} }
sshSession, err := hs.sessionRegistry.Get(slug) sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil { if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
@@ -331,8 +335,8 @@ func (hs *httpServer) handler(conn net.Conn) {
return return
} }
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) { func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) {
payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr()) payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct { type channelResult struct {
channel ssh.Channel channel ssh.Channel
@@ -342,7 +346,7 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -353,14 +357,14 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
case result := <-resultChan: case result := <-resultChan:
if result.err != nil { if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter()) sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
channel = result.channel channel = result.channel
reqs = result.reqs reqs = result.reqs
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel") log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter()) sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
@@ -386,6 +390,6 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
return return
} }
sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr()) sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return return
} }
+5 -1
View File
@@ -9,6 +9,7 @@ import (
"net" "net"
"strings" "strings"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/types"
) )
func (hs *httpServer) ListenAndServeTLS() error { func (hs *httpServer) ListenAndServeTLS() error {
@@ -89,7 +90,10 @@ func (hs *httpServer) handlerTLS(conn net.Conn) {
return return
} }
sshSession, err := hs.sessionRegistry.Get(slug) sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil { if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
+13 -11
View File
@@ -2,9 +2,11 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
"net" "net"
"time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/session" "tunnel_pls/session"
@@ -64,31 +66,31 @@ func (s *Server) Start() {
func (s *Server) handleConnection(conn net.Conn) { func (s *Server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
defer func(sshConn *ssh.ServerConn) {
err = sshConn.Close()
if err != nil {
log.Printf("failed to close SSH server: %v", err)
}
}(sshConn)
if err != nil { if err != nil {
log.Printf("failed to establish SSH connection: %v", err) log.Printf("failed to establish SSH connection: %v", err)
err := conn.Close() err = conn.Close()
if err != nil { if err != nil {
log.Printf("failed to close SSH connection: %v", err) log.Printf("failed to close SSH connection: %v", err)
return return
} }
return return
} }
ctx := context.Background()
log.Println("SSH connection established:", sshConn.User()) defer func(sshConn *ssh.ServerConn) {
err = sshConn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("failed to close SSH server: %v", err)
}
}(sshConn)
user := "UNAUTHORIZED" user := "UNAUTHORIZED"
if s.grpcClient != nil { if s.grpcClient != nil {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
_, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User()) _, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User())
user = u user = u
cancel()
} }
log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user) sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user)
err = sshSession.Start() err = sshSession.Start()
if err != nil { if err != nil {
+30 -30
View File
@@ -30,50 +30,50 @@ func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
return io.CopyBuffer(dst, src, buf) return io.CopyBuffer(dst, src, buf)
} }
type Forwarder struct { type forwarder struct {
listener net.Listener listener net.Listener
tunnelType types.TunnelType tunnelType types.TunnelType
forwardedPort uint16 forwardedPort uint16
slugManager slug.Manager slug slug.Slug
lifecycle Lifecycle lifecycle Lifecycle
} }
func NewForwarder(slugManager slug.Manager) *Forwarder { func New(slug slug.Slug) Forwarder {
return &Forwarder{ return &forwarder{
listener: nil, listener: nil,
tunnelType: "", tunnelType: types.UNKNOWN,
forwardedPort: 0, forwardedPort: 0,
slugManager: slugManager, slug: slug,
lifecycle: nil, lifecycle: nil,
} }
} }
type Lifecycle interface { type Lifecycle interface {
GetConnection() ssh.Conn Connection() ssh.Conn
} }
type ForwardingController interface { type Forwarder interface {
AcceptTCPConnections()
SetType(tunnelType types.TunnelType) SetType(tunnelType types.TunnelType)
GetTunnelType() types.TunnelType SetLifecycle(lifecycle Lifecycle)
GetForwardedPort() uint16
SetForwardedPort(port uint16) SetForwardedPort(port uint16)
SetListener(listener net.Listener) SetListener(listener net.Listener)
GetListener() net.Listener Listener() net.Listener
Close() error TunnelType() types.TunnelType
ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle)
CreateForwardedTCPIPPayload(origin net.Addr) []byte CreateForwardedTCPIPPayload(origin net.Addr) []byte
WriteBadGatewayResponse(dst io.Writer) WriteBadGatewayResponse(dst io.Writer)
AcceptTCPConnections()
Close() error
} }
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { func (f *forwarder) SetLifecycle(lifecycle Lifecycle) {
f.lifecycle = lifecycle f.lifecycle = lifecycle
} }
func (f *Forwarder) AcceptTCPConnections() { func (f *forwarder) AcceptTCPConnections() {
for { for {
conn, err := f.GetListener().Accept() conn, err := f.Listener().Accept()
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return return
@@ -100,7 +100,7 @@ func (f *Forwarder) AcceptTCPConnections() {
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := f.lifecycle.Connection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -130,7 +130,7 @@ func (f *Forwarder) AcceptTCPConnections() {
} }
} }
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
defer func() { defer func() {
_, err := io.Copy(io.Discard, src) _, err := io.Copy(io.Discard, src)
if err != nil { if err != nil {
@@ -174,31 +174,31 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
wg.Wait() wg.Wait()
} }
func (f *Forwarder) SetType(tunnelType types.TunnelType) { func (f *forwarder) SetType(tunnelType types.TunnelType) {
f.tunnelType = tunnelType f.tunnelType = tunnelType
} }
func (f *Forwarder) GetTunnelType() types.TunnelType { func (f *forwarder) TunnelType() types.TunnelType {
return f.tunnelType return f.tunnelType
} }
func (f *Forwarder) GetForwardedPort() uint16 { func (f *forwarder) ForwardedPort() uint16 {
return f.forwardedPort return f.forwardedPort
} }
func (f *Forwarder) SetForwardedPort(port uint16) { func (f *forwarder) SetForwardedPort(port uint16) {
f.forwardedPort = port f.forwardedPort = port
} }
func (f *Forwarder) SetListener(listener net.Listener) { func (f *forwarder) SetListener(listener net.Listener) {
f.listener = listener f.listener = listener
} }
func (f *Forwarder) GetListener() net.Listener { func (f *forwarder) Listener() net.Listener {
return f.listener return f.listener
} }
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
_, err := dst.Write(types.BadGatewayResponse) _, err := dst.Write(types.BadGatewayResponse)
if err != nil { if err != nil {
log.Printf("failed to write Bad Gateway response: %v", err) log.Printf("failed to write Bad Gateway response: %v", err)
@@ -206,20 +206,20 @@ func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
} }
} }
func (f *Forwarder) Close() error { func (f *forwarder) Close() error {
if f.GetListener() != nil { if f.Listener() != nil {
return f.listener.Close() return f.listener.Close()
} }
return nil return nil
} }
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
var buf bytes.Buffer var buf bytes.Buffer
host, originPort := parseAddr(origin.String()) host, originPort := parseAddr(origin.String())
writeSSHString(&buf, "localhost") writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort())) err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
if err != nil { if err != nil {
log.Printf("Failed to write string to buffer: %v", err) log.Printf("Failed to write string to buffer: %v", err)
return nil return nil
+94 -133
View File
@@ -15,7 +15,7 @@ import (
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest { for req := range GlobalRequest {
switch req.Type { switch req.Type {
case "shell", "pty-req": case "shell", "pty-req":
@@ -56,211 +56,172 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
} }
} }
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { func (s *session) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected") log.Println("Port forwarding request detected")
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
}
reader := bytes.NewReader(req.Payload) reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader) addr, err := readSSHString(reader)
if err != nil { if err != nil {
log.Println("Failed to read address from payload:", err) fail(fmt.Sprintf("Failed to read address from payload: %v", err))
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
var rawPortToBind uint32 var rawPortToBind uint32
if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
log.Println("Failed to read port from payload:", err) fail(fmt.Sprintf("Failed to read port from payload: %v", err))
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
if rawPortToBind > 65535 { if rawPortToBind > 65535 {
log.Printf("Port %d is larger than allowed port of 65535", rawPortToBind) fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind))
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
portToBind := uint16(rawPortToBind) portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) { if isBlockedPort(portToBind) {
log.Printf("Port %d is blocked or restricted", portToBind) fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind))
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
if portToBind == 80 || portToBind == 443 { switch portToBind {
case 80, 443:
s.HandleHTTPForward(req, portToBind) s.HandleHTTPForward(req, portToBind)
return default:
}
if portToBind == 0 {
unassign, success := portUtil.Default.GetUnassignedPort()
portToBind = unassign
if !success {
log.Println("No available port")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
log.Printf("Port %d is already in use or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
err = portUtil.Default.SetPortStatus(portToBind, true)
if err != nil {
log.Println("Failed to set port status:", err)
return
}
s.HandleTCPForward(req, addr, portToBind) s.HandleTCPForward(req, addr, portToBind)
} }
}
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
slug := random.GenerateRandomString(20) fail := func(msg string, key *types.SessionKey) {
log.Println(msg)
if !s.registry.Register(slug, s) { if key != nil {
log.Printf("Failed to register client with slug: %s", slug) s.registry.Remove(*key)
err := req.Reply(false, nil) }
if err != nil { if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
} }
}
slug := random.GenerateRandomString(20)
key := types.SessionKey{Id: slug, Type: types.HTTP}
if !s.registry.Register(key, s) {
fail(fmt.Sprintf("Failed to register client with slug: %s", slug), nil)
return return
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
log.Println("Failed to write port to buffer:", err) fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key)
s.registry.Remove(slug)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return return
} }
log.Printf("HTTP forwarding approved on port: %d", portToBind) log.Printf("HTTP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes()) err = req.Reply(true, buf.Bytes())
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) fail(fmt.Sprintf("Failed to reply to request: %v", err), &key)
s.registry.Remove(slug)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return return
} }
s.forwarder.SetType(types.HTTP) s.forwarder.SetType(types.HTTP)
s.forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.slugManager.Set(slug) s.slug.Set(slug)
s.lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
s.interaction.Start()
} }
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
log.Printf("Requested forwarding on %s:%d", addr, portToBind) fail := func(msg string) {
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) log.Println(msg)
if err != nil { if err := req.Reply(false, nil); err != nil {
log.Printf("Port %d is already in use or restricted", portToBind)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.lifecycle.Close() if err := s.lifecycle.Close(); err != nil {
if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
}
cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) {
log.Println(msg)
if key != nil {
s.registry.Remove(*key)
}
if port != 0 {
if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
}
if listener != nil {
if closeErr := listener.Close(); closeErr != nil {
log.Printf("Failed to close listener: %v", closeErr)
}
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
}
_ = s.lifecycle.Close()
}
if portToBind == 0 {
unassigned, ok := portUtil.Default.GetUnassignedPort()
if !ok {
fail("No available port")
return
}
portToBind = unassigned
}
if claimed := portUtil.Default.ClaimPort(portToBind); !claimed {
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
return
}
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil {
cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil)
return
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP}
if !s.registry.Register(key, s) {
cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil)
return return
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
log.Println("Failed to write port to buffer:", err) cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return return
} }
log.Printf("TCP forwarding approved on port: %d", portToBind) log.Printf("TCP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes()) err = req.Reply(true, buf.Bytes())
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return return
} }
s.forwarder.SetType(types.TCP) s.forwarder.SetType(types.TCP)
s.forwarder.SetListener(listener) s.forwarder.SetListener(listener)
s.forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(key.Id)
s.lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
go s.forwarder.AcceptTCPConnections() go s.forwarder.AcceptTCPConnections()
s.interaction.Start()
} }
func readSSHString(reader *bytes.Reader) (string, error) { func readSSHString(reader *bytes.Reader) (string, error) {
+82 -39
View File
@@ -23,35 +23,59 @@ import (
type Lifecycle interface { type Lifecycle interface {
Close() error Close() error
User() string
} }
type Controller interface { type SessionRegistry interface {
Update(user string, oldKey, newKey types.SessionKey) error
}
type Interaction interface {
Mode() types.Mode
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
SetSlugModificator(func(oldSlug, newSlug string) error) SetSessionRegistry(registry SessionRegistry)
Start() SetMode(m types.Mode)
SetWH(w, h int) SetWH(w, h int)
Start()
Redraw() Redraw()
Send(message string) error
} }
type Forwarder interface { type Forwarder interface {
Close() error Close() error
GetTunnelType() types.TunnelType TunnelType() types.TunnelType
GetForwardedPort() uint16 ForwardedPort() uint16
} }
type Interaction struct { type interaction struct {
channel ssh.Channel channel ssh.Channel
slugManager slug.Manager slug slug.Slug
forwarder Forwarder forwarder Forwarder
lifecycle Lifecycle lifecycle Lifecycle
updateClientSlug func(oldSlug, newSlug string) error sessionRegistry SessionRegistry
program *tea.Program program *tea.Program
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
mode types.Mode
} }
func (i *Interaction) SetWH(w, h int) { func (i *interaction) SetMode(m types.Mode) {
i.mode = m
}
func (i *interaction) Mode() types.Mode {
return i.mode
}
func (i *interaction) Send(message string) error {
if i.channel != nil {
_, err := i.channel.Write([]byte(message))
return err
}
return nil
}
func (i *interaction) SetWH(w, h int) {
if i.program != nil { if i.program != nil {
i.program.Send(tea.WindowSizeMsg{ i.program.Send(tea.WindowSizeMsg{
Width: w, Width: w,
@@ -79,14 +103,14 @@ type model struct {
commandList list.Model commandList list.Model
slugInput textinput.Model slugInput textinput.Model
slugError string slugError string
interaction *Interaction interaction *interaction
width int width int
height int height int
} }
func (m *model) getTunnelURL() string { func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP { if m.tunnelType == types.HTTP {
return buildURL(m.protocol, m.interaction.slugManager.Get(), m.domain) return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
} }
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port) return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
} }
@@ -99,33 +123,33 @@ type keymap struct {
type tickMsg time.Time type tickMsg time.Time
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction { func New(slug slug.Slug, forwarder Forwarder) Interaction {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Interaction{ return &interaction{
channel: nil, channel: nil,
slugManager: slugManager, slug: slug,
forwarder: forwarder, forwarder: forwarder,
lifecycle: nil, lifecycle: nil,
updateClientSlug: nil, sessionRegistry: nil,
program: nil, program: nil,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
} }
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { func (i *interaction) SetSessionRegistry(registry SessionRegistry) {
i.sessionRegistry = registry
}
func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle i.lifecycle = lifecycle
} }
func (i *Interaction) SetChannel(channel ssh.Channel) { func (i *interaction) SetChannel(channel ssh.Channel) {
i.channel = channel i.channel = channel
} }
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) error) { func (i *interaction) Stop() {
i.updateClientSlug = modificator
}
func (i *Interaction) Stop() {
if i.cancel != nil { if i.cancel != nil {
i.cancel() i.cancel()
} }
@@ -218,7 +242,13 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter": case "enter":
inputValue := m.slugInput.Value() inputValue := m.slugInput.Value()
if err := m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue); err != nil { if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.User(), types.SessionKey{
Id: m.interaction.slug.String(),
Type: types.HTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.HTTP,
}); err != nil {
m.slugError = err.Error() m.slugError = err.Error()
return m, nil return m, nil
} }
@@ -255,7 +285,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if item.name == "slug" { if item.name == "slug" {
m.showingCommands = false m.showingCommands = false
m.editingSlug = true m.editingSlug = true
m.slugInput.SetValue(m.interaction.slugManager.Get()) m.slugInput.SetValue(m.interaction.slug.String())
m.slugInput.Focus() m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" { } else if item.name == "tunnel-type" {
@@ -287,7 +317,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil return m, nil
} }
func (i *Interaction) Redraw() { func (i *interaction) Redraw() {
if i.program != nil { if i.program != nil {
i.program.Send(tea.ClearScreen()) i.program.Send(tea.ClearScreen())
} }
@@ -661,22 +691,32 @@ func (m *model) View() string {
MarginBottom(boxMargin). MarginBottom(boxMargin).
Width(boxMaxWidth) Width(boxMaxWidth)
urlDisplay := m.getTunnelURL() authenticatedUser := m.interaction.lifecycle.User()
if shouldUseCompactLayout(m.width, 80) && len(urlDisplay) > m.width-20 {
maxLen := m.width - 25 userInfoStyle := lipgloss.NewStyle().
if maxLen > 10 { Foreground(lipgloss.Color("#FAFAFA")).
urlDisplay = truncateString(urlDisplay, maxLen) Bold(true)
}
} sectionHeaderStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Bold(true)
addressStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA"))
var infoContent string var infoContent string
if shouldUseCompactLayout(m.width, 70) { if shouldUseCompactLayout(m.width, 70) {
infoContent = fmt.Sprintf("🌐 %s", urlBoxStyle.Render(urlDisplay)) infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s",
} else if isCompact { userInfoStyle.Render(authenticatedUser),
infoContent = fmt.Sprintf("🌐 Forwarding to:\n\n %s", urlBoxStyle.Render(urlDisplay)) sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL()))))
} else { } else {
infoContent = fmt.Sprintf("🌐 F O R W A R D I N G T O:\n\n %s", urlBoxStyle.Render(urlDisplay)) infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL())))
} }
b.WriteString(responsiveInfoBox.Render(infoContent)) b.WriteString(responsiveInfoBox.Render(infoContent))
b.WriteString("\n") b.WriteString("\n")
@@ -727,7 +767,10 @@ func (m *model) View() string {
return b.String() return b.String()
} }
func (i *Interaction) Start() { func (i *interaction) Start() {
if i.mode == types.HEADLESS {
return
}
lipgloss.SetColorProfile(termenv.TrueColor) lipgloss.SetColorProfile(termenv.TrueColor)
domain := config.Getenv("DOMAIN", "localhost") domain := config.Getenv("DOMAIN", "localhost")
@@ -736,8 +779,8 @@ func (i *Interaction) Start() {
protocol = "https" protocol = "https"
} }
tunnelType := i.forwarder.GetTunnelType() tunnelType := i.forwarder.TunnelType()
port := i.forwarder.GetForwardedPort() port := i.forwarder.ForwardedPort()
items := []list.Item{ items := []list.Item{
commandItem{name: "slug", desc: "Set custom subdomain"}, commandItem{name: "slug", desc: "Set custom subdomain"},
+57 -41
View File
@@ -15,103 +15,119 @@ import (
type Forwarder interface { type Forwarder interface {
Close() error Close() error
GetTunnelType() types.TunnelType TunnelType() types.TunnelType
GetForwardedPort() uint16 ForwardedPort() uint16
} }
type Lifecycle struct { type SessionRegistry interface {
Remove(key types.SessionKey)
}
type lifecycle struct {
status types.Status status types.Status
conn ssh.Conn conn ssh.Conn
channel ssh.Channel channel ssh.Channel
forwarder Forwarder forwarder Forwarder
slugManager slug.Manager sessionRegistry SessionRegistry
unregisterClient func(slug string) slug slug.Slug
startedAt time.Time startedAt time.Time
user string
} }
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle { func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle {
return &Lifecycle{ return &lifecycle{
status: types.INITIALIZING, status: types.INITIALIZING,
conn: conn, conn: conn,
channel: nil, channel: nil,
forwarder: forwarder, forwarder: forwarder,
slugManager: slugManager, slug: slugManager,
unregisterClient: nil, sessionRegistry: nil,
startedAt: time.Now(), startedAt: time.Now(),
user: user,
} }
} }
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) { func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) {
l.unregisterClient = unregisterClient l.sessionRegistry = registry
} }
type SessionLifecycle interface { type Lifecycle interface {
Close() error Connection() ssh.Conn
SetStatus(status types.Status) Channel() ssh.Channel
GetConnection() ssh.Conn User() string
GetChannel() ssh.Channel
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetUnregisterClient(unregisterClient func(slug string)) SetSessionRegistry(registry SessionRegistry)
SetStatus(status types.Status)
IsActive() bool IsActive() bool
StartedAt() time.Time StartedAt() time.Time
Close() error
} }
func (l *Lifecycle) GetChannel() ssh.Channel { func (l *lifecycle) User() string {
return l.user
}
func (l *lifecycle) Channel() ssh.Channel {
return l.channel return l.channel
} }
func (l *Lifecycle) SetChannel(channel ssh.Channel) { func (l *lifecycle) SetChannel(channel ssh.Channel) {
l.channel = channel l.channel = channel
} }
func (l *Lifecycle) GetConnection() ssh.Conn { func (l *lifecycle) Connection() ssh.Conn {
return l.conn return l.conn
} }
func (l *Lifecycle) SetStatus(status types.Status) { func (l *lifecycle) SetStatus(status types.Status) {
l.status = status l.status = status
if status == types.RUNNING && l.startedAt.IsZero() { if status == types.RUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now() l.startedAt = time.Now()
} }
} }
func (l *Lifecycle) Close() error { func (l *lifecycle) Close() error {
err := l.forwarder.Close() var firstErr error
if err != nil && !errors.Is(err, net.ErrClosed) { tunnelType := l.forwarder.TunnelType()
return err
if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
firstErr = err
} }
if l.channel != nil { if l.channel != nil {
err := l.channel.Close() if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
if err != nil && !errors.Is(err, io.EOF) { if firstErr == nil {
return err firstErr = err
}
} }
} }
if l.conn != nil { if l.conn != nil {
err := l.conn.Close() if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if err != nil && !errors.Is(err, net.ErrClosed) { if firstErr == nil {
return err firstErr = err
}
} }
} }
clientSlug := l.slugManager.Get() clientSlug := l.slug.String()
if clientSlug != "" { key := types.SessionKey{
l.unregisterClient(clientSlug) Id: clientSlug,
Type: tunnelType,
} }
l.sessionRegistry.Remove(key)
if l.forwarder.GetTunnelType() == types.TCP { if tunnelType == types.TCP {
err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false) if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
if err != nil { firstErr = err
return err
} }
} }
return nil return firstErr
} }
func (l *Lifecycle) IsActive() bool { func (l *lifecycle) IsActive() bool {
return l.status == types.RUNNING return l.status == types.RUNNING
} }
func (l *Lifecycle) StartedAt() time.Time { func (l *lifecycle) StartedAt() time.Time {
return l.startedAt return l.startedAt
} }
+62 -43
View File
@@ -3,129 +3,148 @@ package session
import ( import (
"fmt" "fmt"
"sync" "sync"
"tunnel_pls/types"
) )
type Key = types.SessionKey
type Registry interface { type Registry interface {
Get(slug string) (session *SSHSession, err error) Get(key Key) (session Session, err error)
Update(oldSlug, newSlug string) error GetWithUser(user string, key Key) (session Session, err error)
Register(slug string, session *SSHSession) (success bool) Update(user string, oldKey, newKey Key) error
Remove(slug string) Register(key Key, session Session) (success bool)
GetAllSessionFromUser(user string) []*SSHSession Remove(key Key)
GetAllSessionFromUser(user string) []Session
} }
type registry struct { type registry struct {
mu sync.RWMutex mu sync.RWMutex
byUser map[string]map[string]*SSHSession byUser map[string]map[Key]Session
slugIndex map[string]string slugIndex map[Key]string
} }
func NewRegistry() Registry { func NewRegistry() Registry {
return &registry{ return &registry{
byUser: make(map[string]map[string]*SSHSession), byUser: make(map[string]map[Key]Session),
slugIndex: make(map[string]string), slugIndex: make(map[Key]string),
} }
} }
func (r *registry) Get(slug string) (session *SSHSession, err error) { func (r *registry) Get(key Key) (session Session, err error) {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
userID, ok := r.slugIndex[slug] userID, ok := r.slugIndex[key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, fmt.Errorf("session not found")
} }
client, ok := r.byUser[userID][slug] client, ok := r.byUser[userID][key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, fmt.Errorf("session not found")
} }
return client, nil return client, nil
} }
func (r *registry) Update(oldSlug, newSlug string) error { func (r *registry) GetWithUser(user string, key Key) (session Session, err error) {
if isForbiddenSlug(newSlug) { r.mu.RLock()
defer r.mu.RUnlock()
client, ok := r.byUser[user][key]
if !ok {
return nil, fmt.Errorf("session not found")
}
return client, nil
}
func (r *registry) Update(user string, oldKey, newKey Key) error {
if oldKey.Type != newKey.Type {
return fmt.Errorf("tunnel type cannot change")
}
if newKey.Type != types.HTTP {
return fmt.Errorf("non http tunnel cannot change slug")
}
if isForbiddenSlug(newKey.Id) {
return fmt.Errorf("this subdomain is reserved. Please choose a different one") return fmt.Errorf("this subdomain is reserved. Please choose a different one")
} else if !isValidSlug(newSlug) { }
if !isValidSlug(newKey.Id) {
return fmt.Errorf("invalid subdomain. Follow the rules") return fmt.Errorf("invalid subdomain. Follow the rules")
} }
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
userID, ok := r.slugIndex[oldSlug] if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
if !ok {
return fmt.Errorf("session not found")
}
if _, exists := r.slugIndex[newSlug]; exists && newSlug != oldSlug {
return fmt.Errorf("someone already uses this subdomain") return fmt.Errorf("someone already uses this subdomain")
} }
client, ok := r.byUser[user][oldKey]
client, ok := r.byUser[userID][oldSlug]
if !ok { if !ok {
return fmt.Errorf("session not found") return fmt.Errorf("session not found")
} }
delete(r.byUser[userID], oldSlug) delete(r.byUser[user], oldKey)
delete(r.slugIndex, oldSlug) delete(r.slugIndex, oldKey)
client.slugManager.Set(newSlug) client.Slug().Set(newKey.Id)
r.slugIndex[newSlug] = userID r.slugIndex[newKey] = user
if r.byUser[userID] == nil { if r.byUser[user] == nil {
r.byUser[userID] = make(map[string]*SSHSession) r.byUser[user] = make(map[Key]Session)
} }
r.byUser[userID][newSlug] = client r.byUser[user][newKey] = client
return nil return nil
} }
func (r *registry) Register(slug string, session *SSHSession) (success bool) { func (r *registry) Register(key Key, session Session) (success bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, exists := r.slugIndex[slug]; exists { if _, exists := r.slugIndex[key]; exists {
return false return false
} }
userID := session.userID userID := session.Lifecycle().User()
if r.byUser[userID] == nil { if r.byUser[userID] == nil {
r.byUser[userID] = make(map[string]*SSHSession) r.byUser[userID] = make(map[Key]Session)
} }
r.byUser[userID][slug] = session r.byUser[userID][key] = session
r.slugIndex[slug] = userID r.slugIndex[key] = userID
return true return true
} }
func (r *registry) GetAllSessionFromUser(user string) []*SSHSession { func (r *registry) GetAllSessionFromUser(user string) []Session {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
m := r.byUser[user] m := r.byUser[user]
if len(m) == 0 { if len(m) == 0 {
return []*SSHSession{} return []Session{}
} }
sessions := make([]*SSHSession, 0, len(m)) sessions := make([]Session, 0, len(m))
for _, s := range m { for _, s := range m {
sessions = append(sessions, s) sessions = append(sessions, s)
} }
return sessions return sessions
} }
func (r *registry) Remove(slug string) { func (r *registry) Remove(key Key) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
userID, ok := r.slugIndex[slug] userID, ok := r.slugIndex[key]
if !ok { if !ok {
return return
} }
delete(r.byUser[userID], slug) delete(r.byUser[userID], key)
if len(r.byUser[userID]) == 0 { if len(r.byUser[userID]) == 0 {
delete(r.byUser, userID) delete(r.byUser, userID)
} }
delete(r.slugIndex, slug) delete(r.slugIndex, key)
} }
func isValidSlug(slug string) bool { func isValidSlug(slug string) bool {
+110 -72
View File
@@ -9,67 +9,11 @@ import (
"tunnel_pls/session/interaction" "tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle" "tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16)
HandleTCPForward(req *ssh.Request, addr string, port uint16)
}
type SSHSession struct {
initialReq <-chan *ssh.Request
sshReqChannel <-chan ssh.NewChannel
lifecycle lifecycle.SessionLifecycle
interaction interaction.Controller
forwarder forwarder.ForwardingController
slugManager slug.Manager
registry Registry
userID string
}
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
return s.lifecycle
}
func (s *SSHSession) GetInteraction() interaction.Controller {
return s.interaction
}
func (s *SSHSession) GetForwarder() forwarder.ForwardingController {
return s.forwarder
}
func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager
}
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession {
slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager)
interactionManager.SetLifecycle(lifecycleManager)
interactionManager.SetSlugModificator(sessionRegistry.Update)
forwarderManager.SetLifecycle(lifecycleManager)
lifecycleManager.SetUnregisterClient(sessionRegistry.Remove)
return &SSHSession{
initialReq: forwardingReq,
sshReqChannel: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slugManager: slugManager,
registry: sessionRegistry,
userID: userID,
}
}
type Detail struct { type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"` ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"` Slug string `json:"slug,omitempty"`
@@ -78,18 +22,94 @@ type Detail struct {
StartedAt time.Time `json:"started_at,omitempty"` StartedAt time.Time `json:"started_at,omitempty"`
} }
func (s *SSHSession) Detail() Detail { type Session interface {
return Detail{ HandleGlobalRequest(ch <-chan *ssh.Request)
ForwardingType: string(s.forwarder.GetTunnelType()), HandleTCPIPForward(req *ssh.Request)
Slug: s.slugManager.Get(), HandleHTTPForward(req *ssh.Request, port uint16)
UserID: s.userID, HandleTCPForward(req *ssh.Request, addr string, port uint16)
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *Detail
Start() error
}
type session struct {
initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel
lifecycle lifecycle.Lifecycle
interaction interaction.Interaction
forwarder forwarder.Forwarder
slug slug.Slug
registry Registry
}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session {
slugManager := slug.New()
forwarderManager := forwarder.New(slugManager)
interactionManager := interaction.New(slugManager, forwarderManager)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager)
forwarderManager.SetLifecycle(lifecycleManager)
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
return &session{
initialReq: initialReq,
sshChan: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slug: slugManager,
registry: sessionRegistry,
}
}
func (s *session) Lifecycle() lifecycle.Lifecycle {
return s.lifecycle
}
func (s *session) Interaction() interaction.Interaction {
return s.interaction
}
func (s *session) Forwarder() forwarder.Forwarder {
return s.forwarder
}
func (s *session) Slug() slug.Slug {
return s.slug
}
func (s *session) Detail() *Detail {
var tunnelType string
if s.forwarder.TunnelType() == types.HTTP {
tunnelType = "HTTP"
} else if s.forwarder.TunnelType() == types.TCP {
tunnelType = "TCP"
} else {
tunnelType = "UNKNOWN"
}
return &Detail{
ForwardingType: tunnelType,
Slug: s.slug.String(),
UserID: s.lifecycle.User(),
Active: s.lifecycle.IsActive(), Active: s.lifecycle.IsActive(),
StartedAt: s.lifecycle.StartedAt(), StartedAt: s.lifecycle.StartedAt(),
} }
} }
func (s *SSHSession) Start() error { func (s *session) Start() error {
channel := <-s.sshReqChannel var channel ssh.NewChannel
var ok bool
select {
case channel, ok = <-s.sshChan:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
ch, reqs, err := channel.Accept() ch, reqs, err := channel.Accept()
if err != nil { if err != nil {
log.Printf("failed to accept channel: %v", err) log.Printf("failed to accept channel: %v", err)
@@ -97,23 +117,30 @@ func (s *SSHSession) Start() error {
} }
go s.HandleGlobalRequest(reqs) go s.HandleGlobalRequest(reqs)
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.INTERACTIVE)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.HEADLESS)
}
tcpipReq := s.waitForTCPIPForward() tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil { if tcpipReq == nil {
_, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))) err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
if err != nil { if err != nil {
return err return err
} }
if err := s.lifecycle.Close(); err != nil { if err = s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return fmt.Errorf("no forwarding Request") return fmt.Errorf("no forwarding Request")
} }
s.lifecycle.SetChannel(ch) if (s.interaction.Mode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.User() == "UNAUTHORIZED" {
s.interaction.SetChannel(ch) if err := tcpipReq.Reply(false, nil); err != nil {
log.Printf("cannot reply to tcpip req: %s\n", err)
s.HandleTCPIPForward(tcpipReq) return err
}
if err := s.lifecycle.Close(); err != nil { if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
return err return err
@@ -121,7 +148,18 @@ func (s *SSHSession) Start() error {
return nil return nil
} }
func (s *SSHSession) waitForTCPIPForward() *ssh.Request { s.HandleTCPIPForward(tcpipReq)
s.interaction.Start()
s.lifecycle.Connection().Wait()
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err
}
return nil
}
func (s *session) waitForTCPIPForward() *ssh.Request {
select { select {
case req, ok := <-s.initialReq: case req, ok := <-s.initialReq:
if !ok { if !ok {
+7 -7
View File
@@ -1,24 +1,24 @@
package slug package slug
type Manager interface { type Slug interface {
Get() string String() string
Set(slug string) Set(slug string)
} }
type manager struct { type slug struct {
slug string slug string
} }
func NewManager() Manager { func New() Slug {
return &manager{ return &slug{
slug: "", slug: "",
} }
} }
func (s *manager) Get() string { func (s *slug) String() string {
return s.slug return s.slug
} }
func (s *manager) Set(slug string) { func (s *slug) Set(slug string) {
s.slug = slug s.slug = slug
} }
+19 -7
View File
@@ -1,20 +1,32 @@
package types package types
type Status string type Status int
const ( const (
INITIALIZING Status = "INITIALIZING" INITIALIZING Status = iota
RUNNING Status = "RUNNING" RUNNING
SETUP Status = "SETUP"
) )
type TunnelType string type Mode int
const ( const (
HTTP TunnelType = "HTTP" INTERACTIVE Mode = iota
TCP TunnelType = "TCP" HEADLESS
) )
type TunnelType int
const (
UNKNOWN TunnelType = iota
HTTP
TCP
)
type SessionKey struct {
Id string
Type TunnelType
}
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" + "Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" + "Content-Type: text/plain\r\n\r\n" +