6 Commits

Author SHA1 Message Date
5b603d8317 feat: implement sessions request from grpc server
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 4m7s
2026-01-03 21:17:01 +07:00
8fd9f8b567 feat: implement sessions request from grpc server
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Has been cancelled
2026-01-03 20:06:14 +07:00
30e84ac3b7 feat: implement get sessions by user 2026-01-02 22:58:54 +07:00
fd6ffc2500 feat(grpc): integrate slug edit handling 2026-01-02 18:27:48 +07:00
e1cd4ed981 WIP: gRPC integration, initial implementation 2026-01-01 21:03:17 +07:00
96d2b88f95 WIP: gRPC integration, initial implementation 2026-01-01 21:01:15 +07:00
18 changed files with 602 additions and 748 deletions
+21
View File
@@ -0,0 +1,21 @@
name: renovate
on:
schedule:
- cron: "0 0 * * *"
push:
branches:
- staging
jobs:
renovate:
runs-on: ubuntu-latest
container: git.fossy.my.id/renovate-clanker/renovate:latest
steps:
- uses: actions/checkout@v6
- run: renovate
env:
RENOVATE_CONFIG_FILE: ${{ gitea.workspace }}/renovate-config.js
LOG_LEVEL: "debug"
RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }}
GITHUB_COM_TOKEN: ${{ secrets.COM_TOKEN }}
+3 -3
View File
@@ -3,8 +3,8 @@ module tunnel_pls
go 1.25.5 go 1.25.5
require ( require (
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 git.fossy.my.id/bagas/tunnel-please-grpc v1.2.0
github.com/caddyserver/certmagic v0.25.1 github.com/caddyserver/certmagic v0.25.0
github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbles v0.21.0
github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/lipgloss v1.1.0
@@ -19,7 +19,7 @@ require (
require ( require (
github.com/atotto/clipboard v0.1.4 // indirect github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/caddyserver/zerossl v0.1.4 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.3 // indirect github.com/charmbracelet/x/ansi v0.11.3 // indirect
github.com/charmbracelet/x/cellbuf v0.0.14 // indirect github.com/charmbracelet/x/cellbuf v0.0.14 // indirect
+2 -10
View File
@@ -1,9 +1,5 @@
git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0 h1:RhcBKUG41/om4jgN+iF/vlY/RojTeX1QhBa4p4428ec= git.fossy.my.id/bagas/tunnel-please-grpc v1.2.0 h1:BS1dJU3wa2ILgTGwkV95Knle0il0OQtErGqyb6xV7SU=
git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY= git.fossy.my.id/bagas/tunnel-please-grpc v1.2.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY=
git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0 h1:tpJSKjaSmV+vxxbVx6qnStjxFVXjj2M0rygWXxLb99o=
git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY=
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA=
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
@@ -12,12 +8,8 @@ github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWp
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic= github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic=
github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA= github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA=
github.com/caddyserver/certmagic v0.25.1 h1:4sIKKbOt5pg6+sL7tEwymE1x2bj6CHr80da1CRRIPbY=
github.com/caddyserver/certmagic v0.25.1/go.mod h1:VhyvndxtVton/Fo/wKhRoC46Rbw1fmjvQ3GjHYSQTEY=
github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA= github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA=
github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/caddyserver/zerossl v0.1.4 h1:CVJOE3MZeFisCERZjkxIcsqIH4fnFdlYWnPYeFtBHRw=
github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
+114 -158
View File
@@ -9,7 +9,6 @@ import (
"log" "log"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/types"
"tunnel_pls/session" "tunnel_pls/session"
@@ -40,9 +39,9 @@ type Client struct {
conn *grpc.ClientConn conn *grpc.ClientConn
config *GrpcConfig config *GrpcConfig
sessionRegistry session.Registry sessionRegistry session.Registry
slugService proto.SlugChangeClient
eventService proto.EventServiceClient eventService proto.EventServiceClient
authorizeConnectionService proto.UserServiceClient authorizeConnectionService proto.UserServiceClient
closing bool
} }
func DefaultConfig() *GrpcConfig { func DefaultConfig() *GrpcConfig {
@@ -114,12 +113,14 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error)
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", config.Address, err) return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", config.Address, err)
} }
slugService := proto.NewSlugChangeClient(conn)
eventService := proto.NewEventServiceClient(conn) eventService := proto.NewEventServiceClient(conn)
authorizeConnectionService := proto.NewUserServiceClient(conn) authorizeConnectionService := proto.NewUserServiceClient(conn)
return &Client{ return &Client{
conn: conn, conn: conn,
config: config, config: config,
slugService: slugService,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
eventService: eventService, eventService: eventService,
authorizeConnectionService: authorizeConnectionService, authorizeConnectionService: authorizeConnectionService,
@@ -154,17 +155,16 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string
for { for {
subscribe, err := c.eventService.Subscribe(ctx) subscribe, err := c.eventService.Subscribe(ctx)
if err != nil { if err != nil {
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil { if !isConnectionError(err) {
return err return err
} }
if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated { if status.Code(err) == codes.Unauthenticated {
return err return err
} }
if err = wait(); err != nil { if err := wait(); err != nil {
return err return err
} }
growBackoff() growBackoff()
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
continue continue
} }
@@ -180,8 +180,8 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string
if err != nil { if err != nil {
log.Println("Authentication failed to send to gRPC server:", err) log.Println("Authentication failed to send to gRPC server:", err)
if c.isConnectionError(err) { if isConnectionError(err) {
if err = wait(); err != nil { if err := wait(); err != nil {
return err return err
} }
growBackoff() growBackoff()
@@ -193,12 +193,9 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string
backoff = baseBackoff backoff = baseBackoff
if err = c.processEventStream(subscribe); err != nil { if err = c.processEventStream(subscribe); err != nil {
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil { if isConnectionError(err) {
return err
}
if c.isConnectionError(err) {
log.Printf("Reconnect to controller within %v sec", backoff.Seconds()) log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
if err = wait(); err != nil { if err := wait(); err != nil {
return err return err
} }
growBackoff() growBackoff()
@@ -210,155 +207,118 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string
} }
func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error { func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error {
handlers := c.eventHandlers(subscribe)
for { for {
recv, err := subscribe.Recv() recv, err := subscribe.Recv()
if err != nil { if err != nil {
return err if isConnectionError(err) {
} log.Printf("connection error receiving from gRPC server: %v", err)
return err
handler, ok := handlers[recv.GetType()] }
if !ok { if status.Code(err) == codes.Unauthenticated {
log.Printf("Unknown event type received: %v", recv.GetType()) log.Printf("Authentication failed: %v", err)
return err
}
log.Printf("non-connection receive error from gRPC server: %v", err)
continue continue
} }
switch recv.GetType() {
if err = handler(recv); err != nil { case proto.EventType_SLUG_CHANGE:
return err oldSlug := recv.GetSlugEvent().GetOld()
newSlug := recv.GetSlugEvent().GetNew()
sess, err := c.sessionRegistry.Get(oldSlug)
if err != nil {
errSend := subscribe.Send(&proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{
Success: false,
Message: err.Error(),
},
},
})
if errSend != nil {
if isConnectionError(errSend) {
log.Printf("connection error sending slug change failure: %v", errSend)
return errSend
}
log.Printf("non-connection send error for slug change failure: %v", errSend)
}
continue
}
err = c.sessionRegistry.Update(oldSlug, newSlug)
if err != nil {
errSend := subscribe.Send(&proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{
Success: false,
Message: err.Error(),
},
},
})
if errSend != nil {
if isConnectionError(errSend) {
log.Printf("connection error sending slug change failure: %v", errSend)
return errSend
}
log.Printf("non-connection send error for slug change failure: %v", errSend)
}
continue
}
sess.GetInteraction().Redraw()
err = subscribe.Send(&proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{
Success: true,
Message: "",
},
},
})
if err != nil {
if isConnectionError(err) {
log.Printf("connection error sending slug change success: %v", err)
return err
}
log.Printf("non-connection send error for slug change success: %v", err)
continue
}
case proto.EventType_GET_SESSIONS:
sessions := c.sessionRegistry.GetAllSessionFromUser(recv.GetGetSessionsEvent().GetIdentity())
var details []*proto.Detail
for _, ses := range sessions {
detail := ses.Detail()
details = append(details, &proto.Detail{
Node: config.Getenv("domain", "localhost"),
ForwardingType: detail.ForwardingType,
Slug: detail.Slug,
UserId: detail.UserID,
Active: detail.Active,
StartedAt: timestamppb.New(detail.StartedAt),
})
}
err = subscribe.Send(&proto.Node{
Type: proto.EventType_GET_SESSIONS,
Payload: &proto.Node_GetSessionsEvent{
GetSessionsEvent: &proto.GetSessionsResponse{
Details: details,
},
},
})
if err != nil {
if isConnectionError(err) {
log.Printf("connection error sending sessions success: %v", err)
return err
}
log.Printf("non-connection send error for sessions success: %v", err)
continue
}
default:
log.Printf("Unknown event type received: %v", recv.GetType())
} }
} }
} }
func (c *Client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error {
return map[proto.EventType]func(*proto.Events) error{
proto.EventType_SLUG_CHANGE: func(evt *proto.Events) error { return c.handleSlugChange(subscribe, evt) },
proto.EventType_GET_SESSIONS: func(evt *proto.Events) error { return c.handleGetSessions(subscribe, evt) },
proto.EventType_TERMINATE_SESSION: func(evt *proto.Events) error { return c.handleTerminateSession(subscribe, evt) },
}
}
func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
slugEvent := evt.GetSlugEvent()
user := slugEvent.GetUser()
oldSlug := slugEvent.GetOld()
newSlug := slugEvent.GetNew()
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP})
if err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
},
}, "slug change failure response")
}
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
},
}, "slug change failure response")
}
userSession.Interaction().Redraw()
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{Success: true, Message: ""},
},
}, "slug change success response")
}
func (c *Client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
sessions := c.sessionRegistry.GetAllSessionFromUser(evt.GetGetSessionsEvent().GetIdentity())
var details []*proto.Detail
for _, ses := range sessions {
detail := ses.Detail()
details = append(details, &proto.Detail{
Node: config.Getenv("DOMAIN", "localhost"),
ForwardingType: detail.ForwardingType,
Slug: detail.Slug,
UserId: detail.UserID,
Active: detail.Active,
StartedAt: timestamppb.New(detail.StartedAt),
})
}
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_GET_SESSIONS,
Payload: &proto.Node_GetSessionsEvent{
GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
},
}, "send get sessions response")
}
func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
terminate := evt.GetTerminateSessionEvent()
user := terminate.GetUser()
slug := terminate.GetSlug()
tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType())
if err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
},
}, "terminate session invalid tunnel type")
}
userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType})
if err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
},
}, "terminate session fetch failed")
}
if err = userSession.Lifecycle().Close(); err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
},
}, "terminate session close failed")
}
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: true, Message: ""},
},
}, "terminate session success response")
}
func (c *Client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error {
if err := subscribe.Send(node); err != nil {
if c.isConnectionError(err) {
return err
}
log.Printf("%s: %v", context, err)
}
return nil
}
func (c *Client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
switch t {
case proto.TunnelType_HTTP:
return types.HTTP, nil
case proto.TunnelType_TCP:
return types.TCP, nil
default:
return types.UNKNOWN, fmt.Errorf("unknown tunnel type received")
}
}
func (c *Client) GetConnection() *grpc.ClientConn { func (c *Client) GetConnection() *grpc.ClientConn {
return c.conn return c.conn
} }
@@ -378,7 +338,6 @@ func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bo
func (c *Client) Close() error { func (c *Client) Close() error {
if c.conn != nil { if c.conn != nil {
log.Printf("Closing gRPC connection to %s", c.config.Address) log.Printf("Closing gRPC connection to %s", c.config.Address)
c.closing = true
return c.conn.Close() return c.conn.Close()
} }
return nil return nil
@@ -402,10 +361,7 @@ func (c *Client) GetConfig() *GrpcConfig {
return c.config return c.config
} }
func (c *Client) isConnectionError(err error) bool { func isConnectionError(err error) bool {
if c.closing {
return false
}
if err == nil { if err == nil {
return false return false
} }
+6 -16
View File
@@ -13,7 +13,7 @@ type Manager interface {
AddPortRange(startPort, endPort uint16) error AddPortRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool) GetUnassignedPort() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error SetPortStatus(port uint16, assigned bool) error
ClaimPort(port uint16) (claimed bool) GetPortStatus(port uint16) (bool, bool)
} }
type manager struct { type manager struct {
@@ -74,6 +74,7 @@ func (pm *manager) GetUnassignedPort() (uint16, bool) {
for _, port := range pm.sortedPorts { for _, port := range pm.sortedPorts {
if !pm.ports[port] { if !pm.ports[port] {
pm.ports[port] = true
return port, true return port, true
} }
} }
@@ -88,21 +89,10 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
return nil return nil
} }
func (pm *manager) ClaimPort(port uint16) (claimed bool) { func (pm *manager) GetPortStatus(port uint16) (bool, bool) {
pm.mu.Lock() pm.mu.RLock()
defer pm.mu.Unlock() defer pm.mu.RUnlock()
status, exists := pm.ports[port] status, exists := pm.ports[port]
return status, exists
if exists && status {
return false
}
if !exists {
pm.ports[port] = true
return true
}
pm.ports[port] = true
return true
} }
+27 -41
View File
@@ -7,9 +7,7 @@ import (
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"os/signal"
"strings" "strings"
"syscall"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
@@ -70,14 +68,10 @@ func main() {
sshConfig.AddHostKey(private) sshConfig.AddHostKey(private)
sessionRegistry := session.NewRegistry() sessionRegistry := session.NewRegistry()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errChan := make(chan error, 2)
shutdownChan := make(chan os.Signal, 1)
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
var grpcClient *client.Client var grpcClient *client.Client
var cancel context.CancelFunc = func() {}
var ctx context.Context = context.Background()
if isNodeMode { if isNodeMode {
grpcHost := config.Getenv("GRPC_ADDRESS", "localhost") grpcHost := config.Getenv("GRPC_ADDRESS", "localhost")
grpcPort := config.Getenv("GRPC_PORT", "8080") grpcPort := config.Getenv("GRPC_PORT", "8080")
@@ -85,9 +79,10 @@ func main() {
nodeToken := config.Getenv("NODE_TOKEN", "") nodeToken := config.Getenv("NODE_TOKEN", "")
if nodeToken == "" { if nodeToken == "" {
log.Fatalf("NODE_TOKEN is required in node mode") log.Fatalf("NODE_TOKEN is required in node mode")
return
} }
c, err := client.New(&client.GrpcConfig{ grpcClient, err = client.New(&client.GrpcConfig{
Address: grpcAddr, Address: grpcAddr,
UseTLS: false, UseTLS: false,
InsecureSkipVerify: false, InsecureSkipVerify: false,
@@ -96,46 +91,37 @@ func main() {
MaxRetries: 3, MaxRetries: 3,
}, sessionRegistry) }, sessionRegistry)
if err != nil { if err != nil {
log.Fatalf("failed to create grpc client: %v", err) return
} }
grpcClient = c defer func(grpcClient *client.Client) {
err := grpcClient.Close()
if err != nil {
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second) }
if err := grpcClient.CheckServerHealth(healthCtx); err != nil { }(grpcClient)
healthCancel()
log.Fatalf("gRPC health check failed: %v", err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
err = grpcClient.CheckServerHealth(ctx)
if err != nil {
log.Fatalf("gRPC health check failed: %s", err)
return
} }
healthCancel() cancel()
ctx, cancel = context.WithCancel(context.Background())
go func() { go func() {
identity := config.Getenv("DOMAIN", "localhost") identity := config.Getenv("DOMAIN", "localhost")
if err := grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil { err = grpcClient.SubscribeEvents(ctx, identity, nodeToken)
errChan <- fmt.Errorf("failed to subscribe to events: %w", err) if err != nil {
return
} }
}() }()
} }
go func() { app, err := server.NewServer(sshConfig, sessionRegistry, grpcClient)
app, err := server.NewServer(sshConfig, sessionRegistry, grpcClient) if err != nil {
if err != nil { log.Fatalf("Failed to start server: %s", err)
errChan <- fmt.Errorf("failed to start server: %s", err)
return
}
app.Start()
}()
select {
case err := <-errChan:
log.Printf("error happen : %s", err)
case sig := <-shutdownChan:
log.Printf("received signal %s, shutting down", sig)
} }
app.Start()
cancel() cancel()
if grpcClient != nil {
if err := grpcClient.Close(); err != nil {
log.Printf("failed to close grpc conn : %s", err)
}
}
} }
+8
View File
@@ -0,0 +1,8 @@
module.exports = {
"endpoint": "https://git.fossy.my.id/api/v1",
"gitAuthor": "Renovate-Clanker <renovate-bot@fossy.my.id>",
"platform": "gitea",
"onboardingConfigFileName": "renovate.json",
"autodiscover": true,
"optimizeForDisabled": true,
};
+7 -11
View File
@@ -13,7 +13,6 @@ import (
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/session" "tunnel_pls/session"
"tunnel_pls/types"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -314,10 +313,7 @@ func (hs *httpServer) handler(conn net.Conn) {
return return
} }
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ sshSession, err := hs.sessionRegistry.Get(slug)
Id: slug,
Type: types.HTTP,
})
if err != nil { if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
@@ -335,8 +331,8 @@ func (hs *httpServer) handler(conn net.Conn) {
return return
} }
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) { func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr()) payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct { type channelResult struct {
channel ssh.Channel channel ssh.Channel
@@ -346,7 +342,7 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -357,14 +353,14 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
case result := <-resultChan: case result := <-resultChan:
if result.err != nil { if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
channel = result.channel channel = result.channel
reqs = result.reqs reqs = result.reqs
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel") log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter()) sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
@@ -390,6 +386,6 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
return return
} }
sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr()) sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return return
} }
+1 -5
View File
@@ -9,7 +9,6 @@ import (
"net" "net"
"strings" "strings"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/types"
) )
func (hs *httpServer) ListenAndServeTLS() error { func (hs *httpServer) ListenAndServeTLS() error {
@@ -90,10 +89,7 @@ func (hs *httpServer) handlerTLS(conn net.Conn) {
return return
} }
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ sshSession, err := hs.sessionRegistry.Get(slug)
Id: slug,
Type: types.HTTP,
})
if err != nil { if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
+11 -13
View File
@@ -2,11 +2,9 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log" "log"
"net" "net"
"time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/session" "tunnel_pls/session"
@@ -66,31 +64,31 @@ func (s *Server) Start() {
func (s *Server) handleConnection(conn net.Conn) { func (s *Server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
defer func(sshConn *ssh.ServerConn) {
err = sshConn.Close()
if err != nil {
log.Printf("failed to close SSH server: %v", err)
}
}(sshConn)
if err != nil { if err != nil {
log.Printf("failed to establish SSH connection: %v", err) log.Printf("failed to establish SSH connection: %v", err)
err = conn.Close() err := conn.Close()
if err != nil { if err != nil {
log.Printf("failed to close SSH connection: %v", err) log.Printf("failed to close SSH connection: %v", err)
return return
} }
return return
} }
ctx := context.Background()
defer func(sshConn *ssh.ServerConn) { log.Println("SSH connection established:", sshConn.User())
err = sshConn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("failed to close SSH server: %v", err)
}
}(sshConn)
user := "UNAUTHORIZED" user := "UNAUTHORIZED"
if s.grpcClient != nil { if s.grpcClient != nil {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
_, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User()) _, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User())
user = u user = u
cancel()
} }
log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user) sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user)
err = sshSession.Start() err = sshSession.Start()
if err != nil { if err != nil {
+30 -30
View File
@@ -30,50 +30,50 @@ func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
return io.CopyBuffer(dst, src, buf) return io.CopyBuffer(dst, src, buf)
} }
type forwarder struct { type Forwarder struct {
listener net.Listener listener net.Listener
tunnelType types.TunnelType tunnelType types.TunnelType
forwardedPort uint16 forwardedPort uint16
slug slug.Slug slugManager slug.Manager
lifecycle Lifecycle lifecycle Lifecycle
} }
func New(slug slug.Slug) Forwarder { func NewForwarder(slugManager slug.Manager) *Forwarder {
return &forwarder{ return &Forwarder{
listener: nil, listener: nil,
tunnelType: types.UNKNOWN, tunnelType: "",
forwardedPort: 0, forwardedPort: 0,
slug: slug, slugManager: slugManager,
lifecycle: nil, lifecycle: nil,
} }
} }
type Lifecycle interface { type Lifecycle interface {
Connection() ssh.Conn GetConnection() ssh.Conn
} }
type Forwarder interface { type ForwardingController interface {
AcceptTCPConnections()
SetType(tunnelType types.TunnelType) SetType(tunnelType types.TunnelType)
SetLifecycle(lifecycle Lifecycle) GetTunnelType() types.TunnelType
GetForwardedPort() uint16
SetForwardedPort(port uint16) SetForwardedPort(port uint16)
SetListener(listener net.Listener) SetListener(listener net.Listener)
Listener() net.Listener GetListener() net.Listener
TunnelType() types.TunnelType Close() error
ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle)
CreateForwardedTCPIPPayload(origin net.Addr) []byte CreateForwardedTCPIPPayload(origin net.Addr) []byte
WriteBadGatewayResponse(dst io.Writer) WriteBadGatewayResponse(dst io.Writer)
AcceptTCPConnections()
Close() error
} }
func (f *forwarder) SetLifecycle(lifecycle Lifecycle) { func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
f.lifecycle = lifecycle f.lifecycle = lifecycle
} }
func (f *forwarder) AcceptTCPConnections() { func (f *Forwarder) AcceptTCPConnections() {
for { for {
conn, err := f.Listener().Accept() conn, err := f.GetListener().Accept()
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return return
@@ -100,7 +100,7 @@ func (f *forwarder) AcceptTCPConnections() {
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := f.lifecycle.Connection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -130,7 +130,7 @@ func (f *forwarder) AcceptTCPConnections() {
} }
} }
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
defer func() { defer func() {
_, err := io.Copy(io.Discard, src) _, err := io.Copy(io.Discard, src)
if err != nil { if err != nil {
@@ -174,31 +174,31 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
wg.Wait() wg.Wait()
} }
func (f *forwarder) SetType(tunnelType types.TunnelType) { func (f *Forwarder) SetType(tunnelType types.TunnelType) {
f.tunnelType = tunnelType f.tunnelType = tunnelType
} }
func (f *forwarder) TunnelType() types.TunnelType { func (f *Forwarder) GetTunnelType() types.TunnelType {
return f.tunnelType return f.tunnelType
} }
func (f *forwarder) ForwardedPort() uint16 { func (f *Forwarder) GetForwardedPort() uint16 {
return f.forwardedPort return f.forwardedPort
} }
func (f *forwarder) SetForwardedPort(port uint16) { func (f *Forwarder) SetForwardedPort(port uint16) {
f.forwardedPort = port f.forwardedPort = port
} }
func (f *forwarder) SetListener(listener net.Listener) { func (f *Forwarder) SetListener(listener net.Listener) {
f.listener = listener f.listener = listener
} }
func (f *forwarder) Listener() net.Listener { func (f *Forwarder) GetListener() net.Listener {
return f.listener return f.listener
} }
func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) { func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
_, err := dst.Write(types.BadGatewayResponse) _, err := dst.Write(types.BadGatewayResponse)
if err != nil { if err != nil {
log.Printf("failed to write Bad Gateway response: %v", err) log.Printf("failed to write Bad Gateway response: %v", err)
@@ -206,20 +206,20 @@ func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
} }
} }
func (f *forwarder) Close() error { func (f *Forwarder) Close() error {
if f.Listener() != nil { if f.GetListener() != nil {
return f.listener.Close() return f.listener.Close()
} }
return nil return nil
} }
func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
var buf bytes.Buffer var buf bytes.Buffer
host, originPort := parseAddr(origin.String()) host, originPort := parseAddr(origin.String())
writeSSHString(&buf, "localhost") writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort())) err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort()))
if err != nil { if err != nil {
log.Printf("Failed to write string to buffer: %v", err) log.Printf("Failed to write string to buffer: %v", err)
return nil return nil
+135 -96
View File
@@ -15,7 +15,7 @@ import (
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest { for req := range GlobalRequest {
switch req.Type { switch req.Type {
case "shell", "pty-req": case "shell", "pty-req":
@@ -56,172 +56,211 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
} }
} }
func (s *session) HandleTCPIPForward(req *ssh.Request) { func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected") log.Println("Port forwarding request detected")
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
}
reader := bytes.NewReader(req.Payload) reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader) addr, err := readSSHString(reader)
if err != nil { if err != nil {
fail(fmt.Sprintf("Failed to read address from payload: %v", err)) log.Println("Failed to read address from payload:", err)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
var rawPortToBind uint32 var rawPortToBind uint32
if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
fail(fmt.Sprintf("Failed to read port from payload: %v", err)) log.Println("Failed to read port from payload:", err)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
if rawPortToBind > 65535 { if rawPortToBind > 65535 {
fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind)) log.Printf("Port %d is larger than allowed port of 65535", rawPortToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
portToBind := uint16(rawPortToBind) portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) { if isBlockedPort(portToBind) {
fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind)) log.Printf("Port %d is blocked or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
switch portToBind { if portToBind == 80 || portToBind == 443 {
case 80, 443:
s.HandleHTTPForward(req, portToBind) s.HandleHTTPForward(req, portToBind)
default: return
s.HandleTCPForward(req, addr, portToBind)
} }
if portToBind == 0 {
unassign, success := portUtil.Default.GetUnassignedPort()
portToBind = unassign
if !success {
log.Println("No available port")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
log.Printf("Port %d is already in use or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
err = portUtil.Default.SetPortStatus(portToBind, true)
if err != nil {
log.Println("Failed to set port status:", err)
return
}
s.HandleTCPForward(req, addr, portToBind)
} }
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) { func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
fail := func(msg string, key *types.SessionKey) { slug := random.GenerateRandomString(20)
log.Println(msg)
if key != nil { if !s.registry.Register(slug, s) {
s.registry.Remove(*key) log.Printf("Failed to register client with slug: %s", slug)
} err := req.Reply(false, nil)
if err := req.Reply(false, nil); err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
} }
}
slug := random.GenerateRandomString(20)
key := types.SessionKey{Id: slug, Type: types.HTTP}
if !s.registry.Register(key, s) {
fail(fmt.Sprintf("Failed to register client with slug: %s", slug), nil)
return return
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key) log.Println("Failed to write port to buffer:", err)
s.registry.Remove(slug)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return return
} }
log.Printf("HTTP forwarding approved on port: %d", portToBind) log.Printf("HTTP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes()) err = req.Reply(true, buf.Bytes())
if err != nil { if err != nil {
fail(fmt.Sprintf("Failed to reply to request: %v", err), &key) log.Println("Failed to reply to request:", err)
s.registry.Remove(slug)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return return
} }
s.forwarder.SetType(types.HTTP) s.forwarder.SetType(types.HTTP)
s.forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(slug) s.slugManager.Set(slug)
s.lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
s.interaction.Start()
} }
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
}
cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) {
log.Println(msg)
if key != nil {
s.registry.Remove(*key)
}
if port != 0 {
if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
}
if listener != nil {
if closeErr := listener.Close(); closeErr != nil {
log.Printf("Failed to close listener: %v", closeErr)
}
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
}
_ = s.lifecycle.Close()
}
if portToBind == 0 {
unassigned, ok := portUtil.Default.GetUnassignedPort()
if !ok {
fail("No available port")
return
}
portToBind = unassigned
}
if claimed := portUtil.Default.ClaimPort(portToBind); !claimed {
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
return
}
log.Printf("Requested forwarding on %s:%d", addr, portToBind) log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil { if err != nil {
cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil) log.Printf("Port %d is already in use or restricted", portToBind)
return if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
} log.Printf("Failed to reset port status: %v", setErr)
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} err = req.Reply(false, nil)
if err != nil {
if !s.registry.Register(key, s) { log.Println("Failed to reply to request:", err)
cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil) return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return return
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil { if err != nil {
cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key) log.Println("Failed to write port to buffer:", err)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return return
} }
log.Printf("TCP forwarding approved on port: %d", portToBind) log.Printf("TCP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes()) err = req.Reply(true, buf.Bytes())
if err != nil { if err != nil {
cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key) log.Println("Failed to reply to request:", err)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return return
} }
s.forwarder.SetType(types.TCP) s.forwarder.SetType(types.TCP)
s.forwarder.SetListener(listener) s.forwarder.SetListener(listener)
s.forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(key.Id)
s.lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
go s.forwarder.AcceptTCPConnections() go s.forwarder.AcceptTCPConnections()
s.interaction.Start()
} }
func readSSHString(reader *bytes.Reader) (string, error) { func readSSHString(reader *bytes.Reader) (string, error) {
+51 -94
View File
@@ -23,59 +23,35 @@ import (
type Lifecycle interface { type Lifecycle interface {
Close() error Close() error
User() string
} }
type SessionRegistry interface { type Controller interface {
Update(user string, oldKey, newKey types.SessionKey) error
}
type Interaction interface {
Mode() types.Mode
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
SetSessionRegistry(registry SessionRegistry) SetSlugModificator(func(oldSlug, newSlug string) error)
SetMode(m types.Mode)
SetWH(w, h int)
Start() Start()
SetWH(w, h int)
Redraw() Redraw()
Send(message string) error
} }
type Forwarder interface { type Forwarder interface {
Close() error Close() error
TunnelType() types.TunnelType GetTunnelType() types.TunnelType
ForwardedPort() uint16 GetForwardedPort() uint16
} }
type interaction struct { type Interaction struct {
channel ssh.Channel channel ssh.Channel
slug slug.Slug slugManager slug.Manager
forwarder Forwarder forwarder Forwarder
lifecycle Lifecycle lifecycle Lifecycle
sessionRegistry SessionRegistry updateClientSlug func(oldSlug, newSlug string) error
program *tea.Program program *tea.Program
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
mode types.Mode
} }
func (i *interaction) SetMode(m types.Mode) { func (i *Interaction) SetWH(w, h int) {
i.mode = m
}
func (i *interaction) Mode() types.Mode {
return i.mode
}
func (i *interaction) Send(message string) error {
if i.channel != nil {
_, err := i.channel.Write([]byte(message))
return err
}
return nil
}
func (i *interaction) SetWH(w, h int) {
if i.program != nil { if i.program != nil {
i.program.Send(tea.WindowSizeMsg{ i.program.Send(tea.WindowSizeMsg{
Width: w, Width: w,
@@ -103,14 +79,14 @@ type model struct {
commandList list.Model commandList list.Model
slugInput textinput.Model slugInput textinput.Model
slugError string slugError string
interaction *interaction interaction *Interaction
width int width int
height int height int
} }
func (m *model) getTunnelURL() string { func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP { if m.tunnelType == types.HTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain) return buildURL(m.protocol, m.interaction.slugManager.Get(), m.domain)
} }
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port) return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
} }
@@ -123,33 +99,33 @@ type keymap struct {
type tickMsg time.Time type tickMsg time.Time
func New(slug slug.Slug, forwarder Forwarder) Interaction { func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &interaction{ return &Interaction{
channel: nil, channel: nil,
slug: slug, slugManager: slugManager,
forwarder: forwarder, forwarder: forwarder,
lifecycle: nil, lifecycle: nil,
sessionRegistry: nil, updateClientSlug: nil,
program: nil, program: nil,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
} }
func (i *interaction) SetSessionRegistry(registry SessionRegistry) { func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
i.sessionRegistry = registry
}
func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle i.lifecycle = lifecycle
} }
func (i *interaction) SetChannel(channel ssh.Channel) { func (i *Interaction) SetChannel(channel ssh.Channel) {
i.channel = channel i.channel = channel
} }
func (i *interaction) Stop() { func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) error) {
i.updateClientSlug = modificator
}
func (i *Interaction) Stop() {
if i.cancel != nil { if i.cancel != nil {
i.cancel() i.cancel()
} }
@@ -242,13 +218,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter": case "enter":
inputValue := m.slugInput.Value() inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.User(), types.SessionKey{ if err := m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue); err != nil {
Id: m.interaction.slug.String(),
Type: types.HTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.HTTP,
}); err != nil {
m.slugError = err.Error() m.slugError = err.Error()
return m, nil return m, nil
} }
@@ -285,7 +255,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if item.name == "slug" { if item.name == "slug" {
m.showingCommands = false m.showingCommands = false
m.editingSlug = true m.editingSlug = true
m.slugInput.SetValue(m.interaction.slug.String()) m.slugInput.SetValue(m.interaction.slugManager.Get())
m.slugInput.Focus() m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" { } else if item.name == "tunnel-type" {
@@ -317,7 +287,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, nil return m, nil
} }
func (i *interaction) Redraw() { func (i *Interaction) Redraw() {
if i.program != nil { if i.program != nil {
i.program.Send(tea.ClearScreen()) i.program.Send(tea.ClearScreen())
} }
@@ -691,32 +661,22 @@ func (m *model) View() string {
MarginBottom(boxMargin). MarginBottom(boxMargin).
Width(boxMaxWidth) Width(boxMaxWidth)
authenticatedUser := m.interaction.lifecycle.User() urlDisplay := m.getTunnelURL()
if shouldUseCompactLayout(m.width, 80) && len(urlDisplay) > m.width-20 {
userInfoStyle := lipgloss.NewStyle(). maxLen := m.width - 25
Foreground(lipgloss.Color("#FAFAFA")). if maxLen > 10 {
Bold(true) urlDisplay = truncateString(urlDisplay, maxLen)
}
sectionHeaderStyle := lipgloss.NewStyle(). }
Foreground(lipgloss.Color("#888888")).
Bold(true)
addressStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA"))
var infoContent string var infoContent string
if shouldUseCompactLayout(m.width, 70) { if shouldUseCompactLayout(m.width, 70) {
infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s", infoContent = fmt.Sprintf("🌐 %s", urlBoxStyle.Render(urlDisplay))
userInfoStyle.Render(authenticatedUser), } else if isCompact {
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"), infoContent = fmt.Sprintf("🌐 Forwarding to:\n\n %s", urlBoxStyle.Render(urlDisplay))
addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL()))))
} else { } else {
infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s", infoContent = fmt.Sprintf("🌐 F O R W A R D I N G T O:\n\n %s", urlBoxStyle.Render(urlDisplay))
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL())))
} }
b.WriteString(responsiveInfoBox.Render(infoContent)) b.WriteString(responsiveInfoBox.Render(infoContent))
b.WriteString("\n") b.WriteString("\n")
@@ -767,10 +727,7 @@ func (m *model) View() string {
return b.String() return b.String()
} }
func (i *interaction) Start() { func (i *Interaction) Start() {
if i.mode == types.HEADLESS {
return
}
lipgloss.SetColorProfile(termenv.TrueColor) lipgloss.SetColorProfile(termenv.TrueColor)
domain := config.Getenv("DOMAIN", "localhost") domain := config.Getenv("DOMAIN", "localhost")
@@ -779,8 +736,8 @@ func (i *interaction) Start() {
protocol = "https" protocol = "https"
} }
tunnelType := i.forwarder.TunnelType() tunnelType := i.forwarder.GetTunnelType()
port := i.forwarder.ForwardedPort() port := i.forwarder.GetForwardedPort()
items := []list.Item{ items := []list.Item{
commandItem{name: "slug", desc: "Set custom subdomain"}, commandItem{name: "slug", desc: "Set custom subdomain"},
+51 -67
View File
@@ -15,119 +15,103 @@ import (
type Forwarder interface { type Forwarder interface {
Close() error Close() error
TunnelType() types.TunnelType GetTunnelType() types.TunnelType
ForwardedPort() uint16 GetForwardedPort() uint16
} }
type SessionRegistry interface { type Lifecycle struct {
Remove(key types.SessionKey) status types.Status
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
slugManager slug.Manager
unregisterClient func(slug string)
startedAt time.Time
} }
type lifecycle struct { func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
status types.Status return &Lifecycle{
conn ssh.Conn status: types.INITIALIZING,
channel ssh.Channel conn: conn,
forwarder Forwarder channel: nil,
sessionRegistry SessionRegistry forwarder: forwarder,
slug slug.Slug slugManager: slugManager,
startedAt time.Time unregisterClient: nil,
user string startedAt: time.Now(),
}
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle {
return &lifecycle{
status: types.INITIALIZING,
conn: conn,
channel: nil,
forwarder: forwarder,
slug: slugManager,
sessionRegistry: nil,
startedAt: time.Now(),
user: user,
} }
} }
func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) { func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
l.sessionRegistry = registry l.unregisterClient = unregisterClient
} }
type Lifecycle interface { type SessionLifecycle interface {
Connection() ssh.Conn Close() error
Channel() ssh.Channel
User() string
SetChannel(channel ssh.Channel)
SetSessionRegistry(registry SessionRegistry)
SetStatus(status types.Status) SetStatus(status types.Status)
GetConnection() ssh.Conn
GetChannel() ssh.Channel
SetChannel(channel ssh.Channel)
SetUnregisterClient(unregisterClient func(slug string))
IsActive() bool IsActive() bool
StartedAt() time.Time StartedAt() time.Time
Close() error
} }
func (l *lifecycle) User() string { func (l *Lifecycle) GetChannel() ssh.Channel {
return l.user
}
func (l *lifecycle) Channel() ssh.Channel {
return l.channel return l.channel
} }
func (l *lifecycle) SetChannel(channel ssh.Channel) { func (l *Lifecycle) SetChannel(channel ssh.Channel) {
l.channel = channel l.channel = channel
} }
func (l *lifecycle) Connection() ssh.Conn { func (l *Lifecycle) GetConnection() ssh.Conn {
return l.conn return l.conn
} }
func (l *lifecycle) SetStatus(status types.Status) { func (l *Lifecycle) SetStatus(status types.Status) {
l.status = status l.status = status
if status == types.RUNNING && l.startedAt.IsZero() { if status == types.RUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now() l.startedAt = time.Now()
} }
} }
func (l *lifecycle) Close() error { func (l *Lifecycle) Close() error {
var firstErr error err := l.forwarder.Close()
tunnelType := l.forwarder.TunnelType() if err != nil && !errors.Is(err, net.ErrClosed) {
return err
if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
firstErr = err
} }
if l.channel != nil { if l.channel != nil {
if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { err := l.channel.Close()
if firstErr == nil { if err != nil && !errors.Is(err, io.EOF) {
firstErr = err return err
}
} }
} }
if l.conn != nil { if l.conn != nil {
if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { err := l.conn.Close()
if firstErr == nil { if err != nil && !errors.Is(err, net.ErrClosed) {
firstErr = err return err
}
} }
} }
clientSlug := l.slug.String() clientSlug := l.slugManager.Get()
key := types.SessionKey{ if clientSlug != "" {
Id: clientSlug, l.unregisterClient(clientSlug)
Type: tunnelType,
} }
l.sessionRegistry.Remove(key)
if tunnelType == types.TCP { if l.forwarder.GetTunnelType() == types.TCP {
if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
firstErr = err if err != nil {
return err
} }
} }
return firstErr return nil
} }
func (l *lifecycle) IsActive() bool { func (l *Lifecycle) IsActive() bool {
return l.status == types.RUNNING return l.status == types.RUNNING
} }
func (l *lifecycle) StartedAt() time.Time { func (l *Lifecycle) StartedAt() time.Time {
return l.startedAt return l.startedAt
} }
+47 -66
View File
@@ -3,148 +3,129 @@ package session
import ( import (
"fmt" "fmt"
"sync" "sync"
"tunnel_pls/types"
) )
type Key = types.SessionKey
type Registry interface { type Registry interface {
Get(key Key) (session Session, err error) Get(slug string) (session *SSHSession, err error)
GetWithUser(user string, key Key) (session Session, err error) Update(oldSlug, newSlug string) error
Update(user string, oldKey, newKey Key) error Register(slug string, session *SSHSession) (success bool)
Register(key Key, session Session) (success bool) Remove(slug string)
Remove(key Key) GetAllSessionFromUser(user string) []*SSHSession
GetAllSessionFromUser(user string) []Session
} }
type registry struct { type registry struct {
mu sync.RWMutex mu sync.RWMutex
byUser map[string]map[Key]Session byUser map[string]map[string]*SSHSession
slugIndex map[Key]string slugIndex map[string]string
} }
func NewRegistry() Registry { func NewRegistry() Registry {
return &registry{ return &registry{
byUser: make(map[string]map[Key]Session), byUser: make(map[string]map[string]*SSHSession),
slugIndex: make(map[Key]string), slugIndex: make(map[string]string),
} }
} }
func (r *registry) Get(key Key) (session Session, err error) { func (r *registry) Get(slug string) (session *SSHSession, err error) {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
userID, ok := r.slugIndex[key] userID, ok := r.slugIndex[slug]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, fmt.Errorf("session not found")
} }
client, ok := r.byUser[userID][key] client, ok := r.byUser[userID][slug]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, fmt.Errorf("session not found")
} }
return client, nil return client, nil
} }
func (r *registry) GetWithUser(user string, key Key) (session Session, err error) { func (r *registry) Update(oldSlug, newSlug string) error {
r.mu.RLock() if isForbiddenSlug(newSlug) {
defer r.mu.RUnlock()
client, ok := r.byUser[user][key]
if !ok {
return nil, fmt.Errorf("session not found")
}
return client, nil
}
func (r *registry) Update(user string, oldKey, newKey Key) error {
if oldKey.Type != newKey.Type {
return fmt.Errorf("tunnel type cannot change")
}
if newKey.Type != types.HTTP {
return fmt.Errorf("non http tunnel cannot change slug")
}
if isForbiddenSlug(newKey.Id) {
return fmt.Errorf("this subdomain is reserved. Please choose a different one") return fmt.Errorf("this subdomain is reserved. Please choose a different one")
} } else if !isValidSlug(newSlug) {
if !isValidSlug(newKey.Id) {
return fmt.Errorf("invalid subdomain. Follow the rules") return fmt.Errorf("invalid subdomain. Follow the rules")
} }
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { userID, ok := r.slugIndex[oldSlug]
return fmt.Errorf("someone already uses this subdomain")
}
client, ok := r.byUser[user][oldKey]
if !ok { if !ok {
return fmt.Errorf("session not found") return fmt.Errorf("session not found")
} }
delete(r.byUser[user], oldKey) if _, exists := r.slugIndex[newSlug]; exists && newSlug != oldSlug {
delete(r.slugIndex, oldKey) return fmt.Errorf("someone already uses this subdomain")
client.Slug().Set(newKey.Id)
r.slugIndex[newKey] = user
if r.byUser[user] == nil {
r.byUser[user] = make(map[Key]Session)
} }
r.byUser[user][newKey] = client
client, ok := r.byUser[userID][oldSlug]
if !ok {
return fmt.Errorf("session not found")
}
delete(r.byUser[userID], oldSlug)
delete(r.slugIndex, oldSlug)
client.slugManager.Set(newSlug)
r.slugIndex[newSlug] = userID
if r.byUser[userID] == nil {
r.byUser[userID] = make(map[string]*SSHSession)
}
r.byUser[userID][newSlug] = client
return nil return nil
} }
func (r *registry) Register(key Key, session Session) (success bool) { func (r *registry) Register(slug string, session *SSHSession) (success bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, exists := r.slugIndex[key]; exists { if _, exists := r.slugIndex[slug]; exists {
return false return false
} }
userID := session.Lifecycle().User() userID := session.userID
if r.byUser[userID] == nil { if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]Session) r.byUser[userID] = make(map[string]*SSHSession)
} }
r.byUser[userID][key] = session r.byUser[userID][slug] = session
r.slugIndex[key] = userID r.slugIndex[slug] = userID
return true return true
} }
func (r *registry) GetAllSessionFromUser(user string) []Session { func (r *registry) GetAllSessionFromUser(user string) []*SSHSession {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
m := r.byUser[user] m := r.byUser[user]
if len(m) == 0 { if len(m) == 0 {
return []Session{} return []*SSHSession{}
} }
sessions := make([]Session, 0, len(m)) sessions := make([]*SSHSession, 0, len(m))
for _, s := range m { for _, s := range m {
sessions = append(sessions, s) sessions = append(sessions, s)
} }
return sessions return sessions
} }
func (r *registry) Remove(key Key) { func (r *registry) Remove(slug string) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
userID, ok := r.slugIndex[key] userID, ok := r.slugIndex[slug]
if !ok { if !ok {
return return
} }
delete(r.byUser[userID], key) delete(r.byUser[userID], slug)
if len(r.byUser[userID]) == 0 { if len(r.byUser[userID]) == 0 {
delete(r.byUser, userID) delete(r.byUser, userID)
} }
delete(r.slugIndex, key) delete(r.slugIndex, slug)
} }
func isValidSlug(slug string) bool { func isValidSlug(slug string) bool {
+74 -112
View File
@@ -9,11 +9,67 @@ import (
"tunnel_pls/session/interaction" "tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle" "tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16)
HandleTCPForward(req *ssh.Request, addr string, port uint16)
}
type SSHSession struct {
initialReq <-chan *ssh.Request
sshReqChannel <-chan ssh.NewChannel
lifecycle lifecycle.SessionLifecycle
interaction interaction.Controller
forwarder forwarder.ForwardingController
slugManager slug.Manager
registry Registry
userID string
}
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
return s.lifecycle
}
func (s *SSHSession) GetInteraction() interaction.Controller {
return s.interaction
}
func (s *SSHSession) GetForwarder() forwarder.ForwardingController {
return s.forwarder
}
func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager
}
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession {
slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager)
interactionManager.SetLifecycle(lifecycleManager)
interactionManager.SetSlugModificator(sessionRegistry.Update)
forwarderManager.SetLifecycle(lifecycleManager)
lifecycleManager.SetUnregisterClient(sessionRegistry.Remove)
return &SSHSession{
initialReq: forwardingReq,
sshReqChannel: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slugManager: slugManager,
registry: sessionRegistry,
userID: userID,
}
}
type Detail struct { type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"` ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"` Slug string `json:"slug,omitempty"`
@@ -22,136 +78,42 @@ type Detail struct {
StartedAt time.Time `json:"started_at,omitempty"` StartedAt time.Time `json:"started_at,omitempty"`
} }
type Session interface { func (s *SSHSession) Detail() Detail {
HandleGlobalRequest(ch <-chan *ssh.Request) return Detail{
HandleTCPIPForward(req *ssh.Request) ForwardingType: string(s.forwarder.GetTunnelType()),
HandleHTTPForward(req *ssh.Request, port uint16) Slug: s.slugManager.Get(),
HandleTCPForward(req *ssh.Request, addr string, port uint16) UserID: s.userID,
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *Detail
Start() error
}
type session struct {
initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel
lifecycle lifecycle.Lifecycle
interaction interaction.Interaction
forwarder forwarder.Forwarder
slug slug.Slug
registry Registry
}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session {
slugManager := slug.New()
forwarderManager := forwarder.New(slugManager)
interactionManager := interaction.New(slugManager, forwarderManager)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager)
forwarderManager.SetLifecycle(lifecycleManager)
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
return &session{
initialReq: initialReq,
sshChan: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slug: slugManager,
registry: sessionRegistry,
}
}
func (s *session) Lifecycle() lifecycle.Lifecycle {
return s.lifecycle
}
func (s *session) Interaction() interaction.Interaction {
return s.interaction
}
func (s *session) Forwarder() forwarder.Forwarder {
return s.forwarder
}
func (s *session) Slug() slug.Slug {
return s.slug
}
func (s *session) Detail() *Detail {
var tunnelType string
if s.forwarder.TunnelType() == types.HTTP {
tunnelType = "HTTP"
} else if s.forwarder.TunnelType() == types.TCP {
tunnelType = "TCP"
} else {
tunnelType = "UNKNOWN"
}
return &Detail{
ForwardingType: tunnelType,
Slug: s.slug.String(),
UserID: s.lifecycle.User(),
Active: s.lifecycle.IsActive(), Active: s.lifecycle.IsActive(),
StartedAt: s.lifecycle.StartedAt(), StartedAt: s.lifecycle.StartedAt(),
} }
} }
func (s *session) Start() error { func (s *SSHSession) Start() error {
var channel ssh.NewChannel channel := <-s.sshReqChannel
var ok bool ch, reqs, err := channel.Accept()
select { if err != nil {
case channel, ok = <-s.sshChan: log.Printf("failed to accept channel: %v", err)
if !ok { return err
log.Println("Forwarding request channel closed")
return nil
}
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
return err
}
go s.HandleGlobalRequest(reqs)
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.INTERACTIVE)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.HEADLESS)
} }
go s.HandleGlobalRequest(reqs)
tcpipReq := s.waitForTCPIPForward() tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil { if tcpipReq == nil {
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) _, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))))
if err != nil { if err != nil {
return err return err
} }
if err = s.lifecycle.Close(); err != nil { if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return fmt.Errorf("no forwarding Request") return fmt.Errorf("no forwarding Request")
} }
if (s.interaction.Mode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.User() == "UNAUTHORIZED" { s.lifecycle.SetChannel(ch)
if err := tcpipReq.Reply(false, nil); err != nil { s.interaction.SetChannel(ch)
log.Printf("cannot reply to tcpip req: %s\n", err)
return err
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err
}
return nil
}
s.HandleTCPIPForward(tcpipReq) s.HandleTCPIPForward(tcpipReq)
s.interaction.Start()
s.lifecycle.Connection().Wait()
if err := s.lifecycle.Close(); err != nil { if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
return err return err
@@ -159,7 +121,7 @@ func (s *session) Start() error {
return nil return nil
} }
func (s *session) waitForTCPIPForward() *ssh.Request { func (s *SSHSession) waitForTCPIPForward() *ssh.Request {
select { select {
case req, ok := <-s.initialReq: case req, ok := <-s.initialReq:
if !ok { if !ok {
+7 -7
View File
@@ -1,24 +1,24 @@
package slug package slug
type Slug interface { type Manager interface {
String() string Get() string
Set(slug string) Set(slug string)
} }
type slug struct { type manager struct {
slug string slug string
} }
func New() Slug { func NewManager() Manager {
return &slug{ return &manager{
slug: "", slug: "",
} }
} }
func (s *slug) String() string { func (s *manager) Get() string {
return s.slug return s.slug
} }
func (s *slug) Set(slug string) { func (s *manager) Set(slug string) {
s.slug = slug s.slug = slug
} }
+7 -19
View File
@@ -1,32 +1,20 @@
package types package types
type Status int type Status string
const ( const (
INITIALIZING Status = iota INITIALIZING Status = "INITIALIZING"
RUNNING RUNNING Status = "RUNNING"
SETUP Status = "SETUP"
) )
type Mode int type TunnelType string
const ( const (
INTERACTIVE Mode = iota HTTP TunnelType = "HTTP"
HEADLESS TCP TunnelType = "TCP"
) )
type TunnelType int
const (
UNKNOWN TunnelType = iota
HTTP
TCP
)
type SessionKey struct {
Id string
Type TunnelType
}
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" + "Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" + "Content-Type: text/plain\r\n\r\n" +