From 9f4c24a3f37a47933db0699eea1b3ce4f2af6277 Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 21 Jan 2026 21:55:38 +0700 Subject: [PATCH] refactor(lifecycle): reorder resource closing and simplify Close() - Close channel and connection first, then remove session - Close forwarded port and forwarder at the end for TCP tunnels - Aggregate all errors using errors.Join instead of failing early --- session/lifecycle/lifecycle.go | 40 +++++++++++++++------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index e4ce44f..f9f9d6e 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -2,8 +2,6 @@ package lifecycle import ( "errors" - "io" - "net" "time" portUtil "tunnel_pls/internal/port" @@ -81,28 +79,23 @@ func (l *lifecycle) SetStatus(status types.SessionStatus) { } } +func closeIfNotNil(c interface{ Close() error }) error { + if c != nil { + return c.Close() + } + return nil +} + func (l *lifecycle) Close() error { - var firstErr error + var errs []error tunnelType := l.forwarder.TunnelType() - if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - firstErr = err + if err := closeIfNotNil(l.channel); err != nil { + errs = append(errs, err) } - if l.channel != nil { - if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - if firstErr == nil { - firstErr = err - } - } - } - - if l.conn != nil { - if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - if firstErr == nil { - firstErr = err - } - } + if err := closeIfNotNil(l.conn); err != nil { + errs = append(errs, err) } clientSlug := l.slug.String() @@ -113,12 +106,15 @@ func (l *lifecycle) Close() error { l.sessionRegistry.Remove(key) if tunnelType == types.TunnelTypeTCP { - if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { - firstErr = err + if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil { + errs = append(errs, err) + } + if err := l.forwarder.Close(); err != nil { + errs = append(errs, err) } } - return firstErr + return errors.Join(errs...) } func (l *lifecycle) IsActive() bool {