chore(restructure): reorganize project layout
- Reorganize internal packages and overall project structure - Update imports and wiring to match the new layout - Separate HTTP parsing and streaming from the server package - Separate middleware from the server package - Separate session registry from the session package - Move HTTP, HTTPS, and TCP servers to the transport package - Session package no longer starts the TCP server directly - Server package no longer starts HTTP/HTTPS servers on initialization - Forwarder no longer handles accepting TCP requests - Move session details to the types package - HTTP/HTTPS initialization is now the responsibility of main
This commit is contained in:
@@ -56,14 +56,14 @@ type Forwarder interface {
|
||||
Listener() net.Listener
|
||||
TunnelType() types.TunnelType
|
||||
ForwardedPort() uint16
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
||||
WriteBadGatewayResponse(dst io.Writer)
|
||||
AcceptTCPConnections()
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
@@ -95,38 +95,6 @@ func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) handleIncomingConnection(conn net.Conn) {
|
||||
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||
|
||||
channel, reqs, err := f.openForwardedChannel(payload)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
go f.HandleConnection(conn, channel, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (f *forwarder) AcceptTCPConnections() {
|
||||
for {
|
||||
conn, err := f.Listener().Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go f.handleIncomingConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func closeWriter(w io.Writer) error {
|
||||
if cw, ok := w.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
@@ -145,12 +113,12 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string)
|
||||
}
|
||||
|
||||
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
|
||||
errs = append(errs, fmt.Errorf("close writer error (%s): %w", direction, err))
|
||||
errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||
defer func() {
|
||||
_, err := io.Copy(io.Discard, src)
|
||||
if err != nil {
|
||||
@@ -158,8 +126,6 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("Handling new forwarded connection from %s", remoteAddr)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
|
||||
@@ -31,11 +31,11 @@ type lifecycle struct {
|
||||
slug slug.Slug
|
||||
startedAt time.Time
|
||||
sessionRegistry SessionRegistry
|
||||
portRegistry portUtil.Registry
|
||||
portRegistry portUtil.Port
|
||||
user string
|
||||
}
|
||||
|
||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle {
|
||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
|
||||
return &lifecycle{
|
||||
status: types.INITIALIZING,
|
||||
conn: conn,
|
||||
@@ -51,7 +51,7 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti
|
||||
|
||||
type Lifecycle interface {
|
||||
Connection() ssh.Conn
|
||||
PortRegistry() portUtil.Registry
|
||||
PortRegistry() portUtil.Port
|
||||
User() string
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetStatus(status types.Status)
|
||||
@@ -60,7 +60,7 @@ type Lifecycle interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (l *lifecycle) PortRegistry() portUtil.Registry {
|
||||
func (l *lifecycle) PortRegistry() portUtil.Port {
|
||||
return l.portRegistry
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ func (l *lifecycle) Close() error {
|
||||
l.sessionRegistry.Remove(key)
|
||||
|
||||
if tunnelType == types.TCP {
|
||||
if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
||||
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,309 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"tunnel_pls/types"
|
||||
)
|
||||
|
||||
type Key = types.SessionKey
|
||||
|
||||
type Registry interface {
|
||||
Get(key Key) (session Session, err error)
|
||||
GetWithUser(user string, key Key) (session Session, err error)
|
||||
Update(user string, oldKey, newKey Key) error
|
||||
Register(key Key, session Session) (success bool)
|
||||
Remove(key Key)
|
||||
GetAllSessionFromUser(user string) []Session
|
||||
}
|
||||
type registry struct {
|
||||
mu sync.RWMutex
|
||||
byUser map[string]map[Key]Session
|
||||
slugIndex map[Key]string
|
||||
}
|
||||
|
||||
func NewRegistry() Registry {
|
||||
return ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registry) Get(key Key) (session Session, err error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
userID, ok := r.slugIndex[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
client, ok := r.byUser[userID][key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (r *registry) GetWithUser(user string, key Key) (session Session, err error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
client, ok := r.byUser[user][key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
||||
if oldKey.Type != newKey.Type {
|
||||
return fmt.Errorf("tunnel type cannot change")
|
||||
}
|
||||
|
||||
if newKey.Type != types.HTTP {
|
||||
return fmt.Errorf("non http tunnel cannot change slug")
|
||||
}
|
||||
|
||||
if isForbiddenSlug(newKey.Id) {
|
||||
return fmt.Errorf("this subdomain is reserved. Please choose a different one")
|
||||
}
|
||||
|
||||
if !isValidSlug(newKey.Id) {
|
||||
return fmt.Errorf("invalid subdomain. Follow the rules")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
||||
return fmt.Errorf("someone already uses this subdomain")
|
||||
}
|
||||
client, ok := r.byUser[user][oldKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
delete(r.byUser[user], oldKey)
|
||||
delete(r.slugIndex, oldKey)
|
||||
|
||||
client.Slug().Set(newKey.Id)
|
||||
r.slugIndex[newKey] = user
|
||||
|
||||
if r.byUser[user] == nil {
|
||||
r.byUser[user] = make(map[Key]Session)
|
||||
}
|
||||
r.byUser[user][newKey] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registry) Register(key Key, session Session) (success bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.slugIndex[key]; exists {
|
||||
return false
|
||||
}
|
||||
|
||||
userID := session.Lifecycle().User()
|
||||
if r.byUser[userID] == nil {
|
||||
r.byUser[userID] = make(map[Key]Session)
|
||||
}
|
||||
|
||||
r.byUser[userID][key] = session
|
||||
r.slugIndex[key] = userID
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *registry) GetAllSessionFromUser(user string) []Session {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
m := r.byUser[user]
|
||||
if len(m) == 0 {
|
||||
return []Session{}
|
||||
}
|
||||
|
||||
sessions := make([]Session, 0, len(m))
|
||||
for _, s := range m {
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (r *registry) Remove(key Key) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
userID, ok := r.slugIndex[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(r.byUser[userID], key)
|
||||
if len(r.byUser[userID]) == 0 {
|
||||
delete(r.byUser, userID)
|
||||
}
|
||||
delete(r.slugIndex, key)
|
||||
}
|
||||
|
||||
func isValidSlug(slug string) bool {
|
||||
if len(slug) < minSlugLength || len(slug) > maxSlugLength {
|
||||
return false
|
||||
}
|
||||
|
||||
if slug[0] == '-' || slug[len(slug)-1] == '-' {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, c := range slug {
|
||||
if !isValidSlugChar(byte(c)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidSlugChar(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-'
|
||||
}
|
||||
|
||||
func isForbiddenSlug(slug string) bool {
|
||||
_, ok := forbiddenSlugs[slug]
|
||||
return ok
|
||||
}
|
||||
|
||||
var forbiddenSlugs = map[string]struct{}{
|
||||
"ping": {},
|
||||
"staging": {},
|
||||
"admin": {},
|
||||
"root": {},
|
||||
"api": {},
|
||||
"www": {},
|
||||
"support": {},
|
||||
"help": {},
|
||||
"status": {},
|
||||
"health": {},
|
||||
"login": {},
|
||||
"logout": {},
|
||||
"signup": {},
|
||||
"register": {},
|
||||
"settings": {},
|
||||
"config": {},
|
||||
"null": {},
|
||||
"undefined": {},
|
||||
"example": {},
|
||||
"test": {},
|
||||
"dev": {},
|
||||
"system": {},
|
||||
"administrator": {},
|
||||
"dashboard": {},
|
||||
"account": {},
|
||||
"profile": {},
|
||||
"user": {},
|
||||
"users": {},
|
||||
"auth": {},
|
||||
"oauth": {},
|
||||
"callback": {},
|
||||
"webhook": {},
|
||||
"webhooks": {},
|
||||
"static": {},
|
||||
"assets": {},
|
||||
"cdn": {},
|
||||
"mail": {},
|
||||
"email": {},
|
||||
"ftp": {},
|
||||
"ssh": {},
|
||||
"git": {},
|
||||
"svn": {},
|
||||
"blog": {},
|
||||
"news": {},
|
||||
"about": {},
|
||||
"contact": {},
|
||||
"terms": {},
|
||||
"privacy": {},
|
||||
"legal": {},
|
||||
"billing": {},
|
||||
"payment": {},
|
||||
"checkout": {},
|
||||
"cart": {},
|
||||
"shop": {},
|
||||
"store": {},
|
||||
"download": {},
|
||||
"uploads": {},
|
||||
"images": {},
|
||||
"img": {},
|
||||
"css": {},
|
||||
"js": {},
|
||||
"fonts": {},
|
||||
"public": {},
|
||||
"private": {},
|
||||
"internal": {},
|
||||
"external": {},
|
||||
"proxy": {},
|
||||
"cache": {},
|
||||
"debug": {},
|
||||
"metrics": {},
|
||||
"monitoring": {},
|
||||
"graphql": {},
|
||||
"rest": {},
|
||||
"rpc": {},
|
||||
"socket": {},
|
||||
"ws": {},
|
||||
"wss": {},
|
||||
"app": {},
|
||||
"apps": {},
|
||||
"mobile": {},
|
||||
"desktop": {},
|
||||
"embed": {},
|
||||
"widget": {},
|
||||
"docs": {},
|
||||
"documentation": {},
|
||||
"wiki": {},
|
||||
"forum": {},
|
||||
"community": {},
|
||||
"feedback": {},
|
||||
"report": {},
|
||||
"abuse": {},
|
||||
"spam": {},
|
||||
"security": {},
|
||||
"verify": {},
|
||||
"confirm": {},
|
||||
"reset": {},
|
||||
"password": {},
|
||||
"recovery": {},
|
||||
"unsubscribe": {},
|
||||
"subscribe": {},
|
||||
"notifications": {},
|
||||
"alerts": {},
|
||||
"messages": {},
|
||||
"inbox": {},
|
||||
"outbox": {},
|
||||
"sent": {},
|
||||
"draft": {},
|
||||
"trash": {},
|
||||
"archive": {},
|
||||
"search": {},
|
||||
"explore": {},
|
||||
"discover": {},
|
||||
"trending": {},
|
||||
"popular": {},
|
||||
"featured": {},
|
||||
"new": {},
|
||||
"latest": {},
|
||||
"top": {},
|
||||
"best": {},
|
||||
"hot": {},
|
||||
"random": {},
|
||||
"all": {},
|
||||
"any": {},
|
||||
"none": {},
|
||||
"true": {},
|
||||
"false": {},
|
||||
}
|
||||
|
||||
var (
|
||||
minSlugLength = 3
|
||||
maxSlugLength = 20
|
||||
)
|
||||
+19
-18
@@ -12,6 +12,8 @@ import (
|
||||
"tunnel_pls/internal/config"
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/internal/transport"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
@@ -21,14 +23,6 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Detail struct {
|
||||
ForwardingType string `json:"forwarding_type,omitempty"`
|
||||
Slug string `json:"slug,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
Active bool `json:"active,omitempty"`
|
||||
StartedAt time.Time `json:"started_at,omitempty"`
|
||||
}
|
||||
|
||||
type Session interface {
|
||||
HandleGlobalRequest(ch <-chan *ssh.Request) error
|
||||
HandleTCPIPForward(req *ssh.Request) error
|
||||
@@ -38,7 +32,7 @@ type Session interface {
|
||||
Interaction() interaction.Interaction
|
||||
Forwarder() forwarder.Forwarder
|
||||
Slug() slug.Slug
|
||||
Detail() *Detail
|
||||
Detail() *types.Detail
|
||||
Start() error
|
||||
}
|
||||
|
||||
@@ -49,12 +43,12 @@ type session struct {
|
||||
interaction interaction.Interaction
|
||||
forwarder forwarder.Forwarder
|
||||
slug slug.Slug
|
||||
registry Registry
|
||||
registry registry.Registry
|
||||
}
|
||||
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session {
|
||||
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
|
||||
slugManager := slug.New()
|
||||
forwarderManager := forwarder.New(slugManager, conn)
|
||||
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
|
||||
@@ -87,7 +81,7 @@ func (s *session) Slug() slug.Slug {
|
||||
return s.slug
|
||||
}
|
||||
|
||||
func (s *session) Detail() *Detail {
|
||||
func (s *session) Detail() *types.Detail {
|
||||
tunnelTypeMap := map[types.TunnelType]string{
|
||||
types.HTTP: "HTTP",
|
||||
types.TCP: "TCP",
|
||||
@@ -97,7 +91,7 @@ func (s *session) Detail() *Detail {
|
||||
tunnelType = "UNKNOWN"
|
||||
}
|
||||
|
||||
return &Detail{
|
||||
return &types.Detail{
|
||||
ForwardingType: tunnelType,
|
||||
Slug: s.slug.String(),
|
||||
UserID: s.lifecycle.User(),
|
||||
@@ -271,7 +265,7 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string,
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort()
|
||||
unassigned, ok := s.lifecycle.PortRegistry().Unassigned()
|
||||
if !ok {
|
||||
return "", 0, fmt.Errorf("no available port")
|
||||
}
|
||||
@@ -328,7 +322,6 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
|
||||
|
||||
if listener != nil {
|
||||
s.forwarder.SetListener(listener)
|
||||
go s.forwarder.AcceptTCPConnections()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -346,7 +339,6 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) error {
|
||||
case 80, 443:
|
||||
return s.HandleHTTPForward(req, port)
|
||||
default:
|
||||
|
||||
return s.HandleTCPForward(req, address, port)
|
||||
}
|
||||
}
|
||||
@@ -369,11 +361,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
|
||||
}
|
||||
|
||||
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
|
||||
if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed {
|
||||
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
|
||||
listener, err := tcpServer.Listen()
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||
}
|
||||
@@ -387,6 +380,14 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
||||
}
|
||||
|
||||
go func() {
|
||||
err = tcpServer.Serve(listener)
|
||||
if err != nil {
|
||||
log.Printf("Failed serving tcp server: %s\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user