From 1171b183406a4ee88ceb19cac5b798fe34eb72d9 Mon Sep 17 00:00:00 2001 From: bagas Date: Fri, 23 Jan 2026 23:51:58 +0700 Subject: [PATCH] refactor: decouple application startup logic from main --- internal/bootstrap/bootstrap.go | 198 ++++++++++++++++++++++++++++++++ internal/config/config.go | 3 + internal/config/loader.go | 5 + main.go | 161 +------------------------- server/server.go | 1 - server/server_test.go | 1 + session/lifecycle/lifecycle.go | 8 +- session/session.go | 1 - 8 files changed, 214 insertions(+), 164 deletions(-) create mode 100644 internal/bootstrap/bootstrap.go diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go new file mode 100644 index 0000000..77ed6e8 --- /dev/null +++ b/internal/bootstrap/bootstrap.go @@ -0,0 +1,198 @@ +package bootstrap + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + "tunnel_pls/internal/config" + "tunnel_pls/internal/grpc/client" + "tunnel_pls/internal/key" + "tunnel_pls/internal/port" + "tunnel_pls/internal/random" + "tunnel_pls/internal/registry" + "tunnel_pls/internal/transport" + "tunnel_pls/internal/version" + "tunnel_pls/server" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" +) + +type Bootstrap struct { + Randomizer random.Random + Config config.Config + SessionRegistry registry.Registry + Port port.Port +} + +func New() (*Bootstrap, error) { + conf, err := config.MustLoad() + if err != nil { + return nil, err + } + + randomizer := random.New() + sessionRegistry := registry.NewRegistry() + + portManager := port.New() + if err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()); err != nil { + return nil, err + } + + return &Bootstrap{ + Randomizer: randomizer, + Config: conf, + SessionRegistry: sessionRegistry, + Port: portManager, + }, nil +} + +func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) { + sshCfg := &ssh.ServerConfig{ + NoClientAuth: true, + ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), + } + + if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { + return nil, fmt.Errorf("generate ssh key: %w", err) + } + privateBytes, err := os.ReadFile(sshKeyPath) + if err != nil { + return nil, fmt.Errorf("read private key: %w", err) + } + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + sshCfg.AddHostKey(private) + return sshCfg, nil +} + +func startGRPCClient(ctx context.Context, conf config.Config, registry registry.Registry, errChan chan<- error) (client.Client, error) { + grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort()) + grpcClient, err := client.New(conf, grpcAddr, registry) + if err != nil { + return nil, err + } + healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second) + defer healthCancel() + if err = grpcClient.CheckServerHealth(healthCtx); err != nil { + return nil, fmt.Errorf("gRPC health check failed: %w", err) + } + + go func() { + if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { + errChan <- fmt.Errorf("failed to subscribe to events: %w", err) + } + }() + + return grpcClient, nil +} + +func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) { + httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), registry, conf.TLSRedirect()) + ln, err := httpserver.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to start http server: %w", err) + return + } + if err = httpserver.Serve(ln); err != nil { + errChan <- fmt.Errorf("error when serving http server: %w", err) + } +} + +func startHTTPSServer(conf config.Config, registry registry.Registry, errChan chan<- error) { + tlsCfg, err := transport.NewTLSConfig(conf) + if err != nil { + errChan <- fmt.Errorf("failed to create TLS config: %w", err) + return + } + httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), registry, conf.TLSRedirect(), tlsCfg) + ln, err := httpsServer.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to start https server: %w", err) + return + } + if err = httpsServer.Serve(ln); err != nil { + errChan <- fmt.Errorf("error when serving https server: %w", err) + } +} + +func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, sshPort string) error { + sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, sshPort) + if err != nil { + return err + } + + sshServer.Start() + + return sshServer.Close() +} + +func startPprof(pprofPort string) { + pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) + log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) + if err := http.ListenAndServe(pprofAddr, nil); err != nil { + log.Printf("pprof server error: %v", err) + } +} + +func (b *Bootstrap) Run() error { + sshConfig, err := newSSHConfig(b.Config.KeyLoc()) + if err != nil { + return fmt.Errorf("failed to create SSH config: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errChan := make(chan error, 5) + shutdownChan := make(chan os.Signal, 1) + signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) + + var grpcClient client.Client + if b.Config.Mode() == types.ServerModeNODE { + grpcClient, err = startGRPCClient(ctx, b.Config, b.SessionRegistry, errChan) + if err != nil { + return fmt.Errorf("failed to start gRPC client: %w", err) + } + defer func(grpcClient client.Client) { + err = grpcClient.Close() + if err != nil { + log.Printf("failed to close gRPC client") + } + }(grpcClient) + } + + go startHTTPServer(b.Config, b.SessionRegistry, errChan) + + if b.Config.TLSEnabled() { + go startHTTPSServer(b.Config, b.SessionRegistry, errChan) + } + + go func() { + if err = startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, grpcClient, b.Port, b.Config.SSHPort()); err != nil { + errChan <- fmt.Errorf("SSH server error: %w", err) + } + }() + + if b.Config.PprofEnabled() { + go startPprof(b.Config.PprofPort()) + } + + log.Println("All services started successfully") + + select { + case err = <-errChan: + return fmt.Errorf("service error: %w", err) + case sig := <-shutdownChan: + log.Printf("Received signal %s, initiating graceful shutdown", sig) + cancel() + return nil + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 62e1aca..5c21abf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,6 +9,8 @@ type Config interface { HTTPPort() string HTTPSPort() string + KeyLoc() string + TLSEnabled() bool TLSRedirect() bool @@ -47,6 +49,7 @@ func (c *config) Domain() string { return c.domain } func (c *config) SSHPort() string { return c.sshPort } func (c *config) HTTPPort() string { return c.httpPort } func (c *config) HTTPSPort() string { return c.httpsPort } +func (c *config) KeyLoc() string { return c.keyLoc } func (c *config) TLSEnabled() bool { return c.tlsEnabled } func (c *config) TLSRedirect() bool { return c.tlsRedirect } func (c *config) ACMEEmail() string { return c.acmeEmail } diff --git a/internal/config/loader.go b/internal/config/loader.go index cde9fd0..ebccf3f 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -18,6 +18,8 @@ type config struct { httpPort string httpsPort string + keyLoc string + tlsEnabled bool tlsRedirect bool @@ -51,6 +53,8 @@ func parse() (*config, error) { httpPort := getenv("HTTP_PORT", "8080") httpsPort := getenv("HTTPS_PORT", "8443") + keyLoc := getenv("KEY_LOC", "certs/privkey.pem") + tlsEnabled := getenvBool("TLS_ENABLED", false) tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false) @@ -85,6 +89,7 @@ func parse() (*config, error) { sshPort: sshPort, httpPort: httpPort, httpsPort: httpsPort, + keyLoc: keyLoc, tlsEnabled: tlsEnabled, tlsRedirect: tlsRedirect, acmeEmail: acmeEmail, diff --git a/main.go b/main.go index d62f722..be8b510 100644 --- a/main.go +++ b/main.go @@ -1,28 +1,11 @@ package main import ( - "context" "fmt" "log" - "net" - "net/http" - _ "net/http/pprof" "os" - "os/signal" - "syscall" - "time" - "tunnel_pls/internal/config" - "tunnel_pls/internal/grpc/client" - "tunnel_pls/internal/key" - "tunnel_pls/internal/port" - "tunnel_pls/internal/random" - "tunnel_pls/internal/registry" - "tunnel_pls/internal/transport" + "tunnel_pls/internal/bootstrap" "tunnel_pls/internal/version" - "tunnel_pls/server" - "tunnel_pls/types" - - "golang.org/x/crypto/ssh" ) func main() { @@ -33,148 +16,14 @@ func main() { log.SetOutput(os.Stdout) log.SetFlags(log.LstdFlags | log.Lshortfile) - log.Printf("Starting %s", version.GetVersion()) - conf, err := config.MustLoad() + boot, err := bootstrap.New() if err != nil { - log.Fatalf("Failed to load configuration: %s", err) - return + log.Fatalf("Startup error: %v", err) } - sshConfig := &ssh.ServerConfig{ - NoClientAuth: true, - ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), - } - - sshKeyPath := "certs/ssh/id_rsa" - if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { - log.Fatalf("Failed to generate SSH key: %s", err) - } - - privateBytes, err := os.ReadFile(sshKeyPath) - if err != nil { - log.Fatalf("Failed to load private key: %s", err) - } - - private, err := ssh.ParsePrivateKey(privateBytes) - if err != nil { - log.Fatalf("Failed to parse private key: %s", err) - } - - sshConfig.AddHostKey(private) - sessionRegistry := registry.NewRegistry() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errChan := make(chan error, 2) - shutdownChan := make(chan os.Signal, 1) - signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) - - var grpcClient client.Client - - if conf.Mode() == types.ServerModeNODE { - grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort()) - - grpcClient, err = client.New(conf, grpcAddr, sessionRegistry) - if err != nil { - log.Fatalf("failed to create grpc client: %v", err) - } - - healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second) - if err = grpcClient.CheckServerHealth(healthCtx); err != nil { - healthCancel() - log.Fatalf("gRPC health check failed: %v", err) - } - healthCancel() - - go func() { - if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { - errChan <- fmt.Errorf("failed to subscribe to events: %w", err) - } - }() - } - - go func() { - var httpListener net.Listener - httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect()) - httpListener, err = httpserver.Listen() - if err != nil { - errChan <- fmt.Errorf("failed to start http server: %w", err) - return - } - err = httpserver.Serve(httpListener) - if err != nil { - errChan <- fmt.Errorf("error when serving http server: %w", err) - return - } - }() - - if conf.TLSEnabled() { - go func() { - var httpsListener net.Listener - tlsConfig, _ := transport.NewTLSConfig(conf) - httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig) - httpsListener, err = httpsServer.Listen() - if err != nil { - errChan <- fmt.Errorf("failed to start http server: %w", err) - return - } - err = httpsServer.Serve(httpsListener) - if err != nil { - errChan <- fmt.Errorf("error when serving http server: %w", err) - return - } - }() - } - - portManager := port.New() - err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()) - if err != nil { - log.Fatalf("Failed to initialize port manager: %s", err) - return - } - randomizer := random.New() - var app server.Server - go func() { - app, err = server.New(randomizer, conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort()) - if err != nil { - errChan <- fmt.Errorf("failed to start server: %s", err) - return - } - app.Start() - - }() - - if conf.PprofEnabled() { - go func() { - pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort()) - log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) - if err = http.ListenAndServe(pprofAddr, nil); err != nil { - log.Printf("pprof server error: %v", err) - } - }() - } - - select { - case err = <-errChan: - log.Printf("error happen : %s", err) - case sig := <-shutdownChan: - log.Printf("received signal %s, shutting down", sig) - } - - cancel() - - if app != nil { - if err = app.Close(); err != nil { - log.Printf("failed to close server : %s", err) - } - } - - if grpcClient != nil { - if err = grpcClient.Close(); err != nil { - log.Printf("failed to close grpc conn : %s", err) - } + if err = boot.Run(); err != nil { + log.Fatalf("Application error: %v", err) } } diff --git a/server/server.go b/server/server.go index 0538d69..d3df5fd 100644 --- a/server/server.go +++ b/server/server.go @@ -114,5 +114,4 @@ func (s *server) handleConnection(conn net.Conn) { log.Printf("SSH session ended with error: %s", err.Error()) return } - return } diff --git a/server/server_test.go b/server/server_test.go index d01479e..de54f18 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -49,6 +49,7 @@ func (m *mockConfig) Mode() types.ServerMode { return m.Called().Get(0).(type func (m *mockConfig) GRPCAddress() string { return m.Called().String(0) } func (m *mockConfig) GRPCPort() string { return m.Called().String(0) } func (m *mockConfig) NodeToken() string { return m.Called().String(0) } +func (m *mockConfig) KeyLoc() string { return m.Called().String(0) } type mockRegistry struct { mock.Mock diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index a775c22..234bff8 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -123,12 +123,8 @@ func (l *lifecycle) Close() error { l.sessionRegistry.Remove(key) if tunnelType == types.TunnelTypeTCP { - if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil { - errs = append(errs, err) - } - if err := l.forwarder.Close(); err != nil { - errs = append(errs, err) - } + errs = append(errs, l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false)) + errs = append(errs, l.forwarder.Close()) } l.closeErr = errors.Join(errs...) diff --git a/session/session.go b/session/session.go index 926aebf..e5d4cc2 100644 --- a/session/session.go +++ b/session/session.go @@ -195,7 +195,6 @@ func (s *session) waitForSessionEnd() error { } if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) return err } return nil