diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 1aad1d2..79c1113 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -9,6 +9,7 @@ import ( "log" "time" "tunnel_pls/internal/config" + "tunnel_pls/types" "tunnel_pls/session" @@ -201,7 +202,6 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string if c.isConnectionError(err) { log.Printf("Reconnect to controller within %v sec", backoff.Seconds()) if err = wait(); err != nil { - fmt.Println(err) return err } growBackoff() @@ -222,7 +222,11 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod case proto.EventType_SLUG_CHANGE: oldSlug := recv.GetSlugEvent().GetOld() newSlug := recv.GetSlugEvent().GetNew() - sess, err := c.sessionRegistry.Get(oldSlug) + var userSession *session.SSHSession + userSession, err = c.sessionRegistry.Get(types.SessionKey{ + Id: oldSlug, + Type: types.HTTP, + }) if err != nil { errSend := subscribe.Send(&proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, @@ -241,7 +245,13 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod } continue } - err = c.sessionRegistry.Update(oldSlug, newSlug) + err = c.sessionRegistry.Update(types.SessionKey{ + Id: oldSlug, + Type: types.HTTP, + }, types.SessionKey{ + Id: newSlug, + Type: types.HTTP, + }) if err != nil { errSend := subscribe.Send(&proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, @@ -260,7 +270,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod } continue } - sess.GetInteraction().Redraw() + userSession.GetInteraction().Redraw() err = subscribe.Send(&proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, Payload: &proto.Node_SlugEventResponse{ diff --git a/server/http.go b/server/http.go index 8add118..2420686 100644 --- a/server/http.go +++ b/server/http.go @@ -13,6 +13,7 @@ import ( "time" "tunnel_pls/internal/config" "tunnel_pls/session" + "tunnel_pls/types" "golang.org/x/crypto/ssh" ) @@ -313,7 +314,10 @@ func (hs *httpServer) handler(conn net.Conn) { return } - sshSession, err := hs.sessionRegistry.Get(slug) + sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ + Id: slug, + Type: types.HTTP, + }) if err != nil { _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + diff --git a/server/https.go b/server/https.go index 3c502a5..2758172 100644 --- a/server/https.go +++ b/server/https.go @@ -9,6 +9,7 @@ import ( "net" "strings" "tunnel_pls/internal/config" + "tunnel_pls/types" ) func (hs *httpServer) ListenAndServeTLS() error { @@ -89,7 +90,10 @@ func (hs *httpServer) handlerTLS(conn net.Conn) { return } - sshSession, err := hs.sessionRegistry.Get(slug) + sshSession, err := hs.sessionRegistry.Get(types.SessionKey{ + Id: slug, + Type: types.HTTP, + }) if err != nil { _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + diff --git a/server/server.go b/server/server.go index 0ee111c..9439d90 100644 --- a/server/server.go +++ b/server/server.go @@ -88,7 +88,6 @@ func (s *Server) handleConnection(conn net.Conn) { _, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User()) user = u } - sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user) err = sshSession.Start() if err != nil { diff --git a/session/handler.go b/session/handler.go index 30458fb..ca080d8 100644 --- a/session/handler.go +++ b/session/handler.go @@ -164,8 +164,9 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { slug := random.GenerateRandomString(20) + key := types.SessionKey{Id: slug, Type: types.HTTP} - if !s.registry.Register(slug, s) { + if !s.registry.Register(key, s) { log.Printf("Failed to register client with slug: %s", slug) err := req.Reply(false, nil) if err != nil { @@ -178,7 +179,7 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) if err != nil { log.Println("Failed to write port to buffer:", err) - s.registry.Remove(slug) + s.registry.Remove(key) err = req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -190,7 +191,7 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) - s.registry.Remove(slug) + s.registry.Remove(key) err = req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) @@ -225,10 +226,29 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind return } + key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} + + if !s.registry.Register(key, s) { + log.Printf("Failed to register TCP client with id: %s", key.Id) + if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { + log.Printf("Failed to reset port status: %v", setErr) + } + if closeErr := listener.Close(); closeErr != nil { + log.Printf("Failed to close listener: %s", closeErr) + } + err = req.Reply(false, nil) + if err != nil { + log.Println("Failed to reply to request:", err) + } + _ = s.lifecycle.Close() + return + } + buf := new(bytes.Buffer) err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) if err != nil { log.Println("Failed to write port to buffer:", err) + s.registry.Remove(key) if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { log.Printf("Failed to reset port status: %v", setErr) } @@ -244,6 +264,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) + s.registry.Remove(key) if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { log.Printf("Failed to reset port status: %v", setErr) } @@ -258,6 +279,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind s.forwarder.SetType(types.TCP) s.forwarder.SetListener(listener) s.forwarder.SetForwardedPort(portToBind) + s.slugManager.Set(key.Id) s.lifecycle.SetStatus(types.RUNNING) go s.forwarder.AcceptTCPConnections() s.interaction.Start() diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index f2dea50..ea96081 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -25,7 +25,7 @@ type Lifecycle struct { channel ssh.Channel forwarder Forwarder slugManager slug.Manager - unregisterClient func(slug string) + unregisterClient func(key types.SessionKey) startedAt time.Time } @@ -41,7 +41,7 @@ func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) } } -func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) { +func (l *Lifecycle) SetUnregisterClient(unregisterClient func(key types.SessionKey)) { l.unregisterClient = unregisterClient } @@ -51,7 +51,7 @@ type SessionLifecycle interface { GetConnection() ssh.Conn GetChannel() ssh.Channel SetChannel(channel ssh.Channel) - SetUnregisterClient(unregisterClient func(slug string)) + SetUnregisterClient(unregisterClient func(key types.SessionKey)) IsActive() bool StartedAt() time.Time } @@ -94,8 +94,9 @@ func (l *Lifecycle) Close() error { } clientSlug := l.slugManager.Get() - if clientSlug != "" { - l.unregisterClient(clientSlug) + if clientSlug != "" && l.unregisterClient != nil { + key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()} + l.unregisterClient(key) } if l.forwarder.GetTunnelType() == types.TCP { diff --git a/session/registry.go b/session/registry.go index b6f655d..6a607f9 100644 --- a/session/registry.go +++ b/session/registry.go @@ -3,96 +3,109 @@ package session import ( "fmt" "sync" + "tunnel_pls/types" ) +type Key = types.SessionKey + type Registry interface { - Get(slug string) (session *SSHSession, err error) - Update(oldSlug, newSlug string) error - Register(slug string, session *SSHSession) (success bool) - Remove(slug string) + Get(key Key) (session *SSHSession, err error) + Update(oldKey, newKey Key) error + Register(key Key, session *SSHSession) (success bool) + Remove(key Key) GetAllSessionFromUser(user string) []*SSHSession } type registry struct { mu sync.RWMutex - byUser map[string]map[string]*SSHSession - slugIndex map[string]string + byUser map[string]map[Key]*SSHSession + slugIndex map[Key]string } func NewRegistry() Registry { return ®istry{ - byUser: make(map[string]map[string]*SSHSession), - slugIndex: make(map[string]string), + byUser: make(map[string]map[Key]*SSHSession), + slugIndex: make(map[Key]string), } } -func (r *registry) Get(slug string) (session *SSHSession, err error) { +func (r *registry) Get(key Key) (session *SSHSession, err error) { r.mu.RLock() defer r.mu.RUnlock() - userID, ok := r.slugIndex[slug] + userID, ok := r.slugIndex[key] if !ok { return nil, fmt.Errorf("session not found") } - client, ok := r.byUser[userID][slug] + client, ok := r.byUser[userID][key] if !ok { return nil, fmt.Errorf("session not found") } return client, nil } -func (r *registry) Update(oldSlug, newSlug string) error { - if isForbiddenSlug(newSlug) { +func (r *registry) Update(oldKey, newKey Key) error { + if oldKey.Type != newKey.Type { + return fmt.Errorf("tunnel type cannot change") + } + + if newKey.Type != types.HTTP { + return fmt.Errorf("non http tunnel cannot change slug") + } + + if isForbiddenSlug(newKey.Id) { return fmt.Errorf("this subdomain is reserved. Please choose a different one") - } else if !isValidSlug(newSlug) { + } + + if !isValidSlug(newKey.Id) { return fmt.Errorf("invalid subdomain. Follow the rules") } r.mu.Lock() defer r.mu.Unlock() - userID, ok := r.slugIndex[oldSlug] + userID, ok := r.slugIndex[oldKey] if !ok { return fmt.Errorf("session not found") } - if _, exists := r.slugIndex[newSlug]; exists && newSlug != oldSlug { + if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { return fmt.Errorf("someone already uses this subdomain") } - client, ok := r.byUser[userID][oldSlug] + client, ok := r.byUser[userID][oldKey] if !ok { return fmt.Errorf("session not found") } - delete(r.byUser[userID], oldSlug) - delete(r.slugIndex, oldSlug) + delete(r.byUser[userID], oldKey) + delete(r.slugIndex, oldKey) - client.slugManager.Set(newSlug) - r.slugIndex[newSlug] = userID + client.slugManager.Set(newKey.Id) + r.slugIndex[newKey] = userID if r.byUser[userID] == nil { - r.byUser[userID] = make(map[string]*SSHSession) + r.byUser[userID] = make(map[Key]*SSHSession) } - r.byUser[userID][newSlug] = client + r.byUser[userID][newKey] = client return nil } -func (r *registry) Register(slug string, session *SSHSession) (success bool) { +func (r *registry) Register(key Key, session *SSHSession) (success bool) { r.mu.Lock() defer r.mu.Unlock() - if _, exists := r.slugIndex[slug]; exists { + if _, exists := r.slugIndex[key]; exists { return false } userID := session.userID if r.byUser[userID] == nil { - r.byUser[userID] = make(map[string]*SSHSession) + r.byUser[userID] = make(map[Key]*SSHSession) } - r.byUser[userID][slug] = session - r.slugIndex[slug] = userID + r.byUser[userID][key] = session + r.slugIndex[key] = userID return true } @@ -112,20 +125,20 @@ func (r *registry) GetAllSessionFromUser(user string) []*SSHSession { return sessions } -func (r *registry) Remove(slug string) { +func (r *registry) Remove(key Key) { r.mu.Lock() defer r.mu.Unlock() - userID, ok := r.slugIndex[slug] + userID, ok := r.slugIndex[key] if !ok { return } - delete(r.byUser[userID], slug) + delete(r.byUser[userID], key) if len(r.byUser[userID]) == 0 { delete(r.byUser, userID) } - delete(r.slugIndex, slug) + delete(r.slugIndex, key) } func isValidSlug(slug string) bool { diff --git a/session/session.go b/session/session.go index 45d497a..d406160 100644 --- a/session/session.go +++ b/session/session.go @@ -9,6 +9,7 @@ import ( "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" "tunnel_pls/session/slug" + "tunnel_pls/types" "golang.org/x/crypto/ssh" ) @@ -54,9 +55,15 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager) interactionManager.SetLifecycle(lifecycleManager) - interactionManager.SetSlugModificator(sessionRegistry.Update) + interactionManager.SetSlugModificator(func(oldSlug, newSlug string) error { + oldKey := types.SessionKey{Id: oldSlug, Type: forwarderManager.GetTunnelType()} + newKey := types.SessionKey{Id: newSlug, Type: forwarderManager.GetTunnelType()} + return sessionRegistry.Update(oldKey, newKey) + }) forwarderManager.SetLifecycle(lifecycleManager) - lifecycleManager.SetUnregisterClient(sessionRegistry.Remove) + lifecycleManager.SetUnregisterClient(func(key types.SessionKey) { + sessionRegistry.Remove(key) + }) return &SSHSession{ initialReq: forwardingReq, diff --git a/types/types.go b/types/types.go index f909da5..5c4eece 100644 --- a/types/types.go +++ b/types/types.go @@ -15,6 +15,11 @@ const ( TCP TunnelType = "TCP" ) +type SessionKey struct { + Id string + Type TunnelType +} + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + "Content-Length: 11\r\n" + "Content-Type: text/plain\r\n\r\n" +