refactor(config): centralize env loading and enforce typed access

- Centralize environment variable loading in config.MustLoad
- Parse and validate all env vars once at initialization
- Make config fields private and read-only
- Remove public Getenv usage in favor of typed accessors
- Improve validation and initialization order
- Normalize enum naming to be idiomatic and avoid constant collisions
This commit is contained in:
2026-01-21 19:43:19 +07:00
parent 1e12373359
commit 2bc20dd991
19 changed files with 414 additions and 257 deletions
+31 -32
View File
@@ -18,37 +18,6 @@ import (
"golang.org/x/crypto/ssh"
)
var bufferPool = sync.Pool{
New: func() interface{} {
bufSize := config.GetBufferSize()
return make([]byte, bufSize)
},
}
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
conn ssh.Conn
}
func New(slug slug.Slug, conn ssh.Conn) Forwarder {
return &forwarder{
listener: nil,
tunnelType: types.UNKNOWN,
forwardedPort: 0,
slug: slug,
conn: conn,
}
}
type Forwarder interface {
SetType(tunnelType types.TunnelType)
SetForwardedPort(port uint16)
@@ -62,6 +31,36 @@ type Forwarder interface {
WriteBadGatewayResponse(dst io.Writer)
Close() error
}
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
conn ssh.Conn
bufferPool sync.Pool
}
func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
return &forwarder{
listener: nil,
tunnelType: types.TunnelTypeUNKNOWN,
forwardedPort: 0,
slug: slug,
conn: conn,
bufferPool: sync.Pool{
New: func() interface{} {
bufSize := config.BufferSize()
return make([]byte, bufSize)
},
},
}
}
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := f.bufferPool.Get().([]byte)
defer f.bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
type channelResult struct {
@@ -107,7 +106,7 @@ func closeWriter(w io.Writer) error {
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
var errs []error
_, err := copyWithBuffer(dst, src)
_, err := f.copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
}
+11 -10
View File
@@ -18,9 +18,9 @@ import (
)
type Interaction interface {
Mode() types.Mode
Mode() types.InteractiveMode
SetChannel(channel ssh.Channel)
SetMode(m types.Mode)
SetMode(m types.InteractiveMode)
SetWH(w, h int)
Start()
Redraw()
@@ -39,6 +39,7 @@ type Forwarder interface {
type CloseFunc func() error
type interaction struct {
config config.Config
channel ssh.Channel
slug slug.Slug
forwarder Forwarder
@@ -48,14 +49,14 @@ type interaction struct {
program *tea.Program
ctx context.Context
cancel context.CancelFunc
mode types.Mode
mode types.InteractiveMode
}
func (i *interaction) SetMode(m types.Mode) {
func (i *interaction) SetMode(m types.InteractiveMode) {
i.mode = m
}
func (i *interaction) Mode() types.Mode {
func (i *interaction) Mode() types.InteractiveMode {
return i.mode
}
@@ -75,9 +76,10 @@ func (i *interaction) SetWH(w, h int) {
}
}
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
ctx, cancel := context.WithCancel(context.Background())
return &interaction{
config: config,
channel: nil,
slug: slug,
forwarder: forwarder,
@@ -174,14 +176,13 @@ func (m *model) View() string {
}
func (i *interaction) Start() {
if i.mode == types.HEADLESS {
if i.mode == types.InteractiveModeHEADLESS {
return
}
lipgloss.SetColorProfile(termenv.TrueColor)
domain := config.Getenv("DOMAIN", "localhost")
protocol := "http"
if config.Getenv("TLS_ENABLED", "false") == "true" {
if i.config.TLSEnabled() {
protocol = "https"
}
@@ -209,7 +210,7 @@ func (i *interaction) Start() {
ti.Width = 50
m := &model{
domain: domain,
domain: i.config.Domain(),
protocol: protocol,
tunnelType: tunnelType,
port: port,
+1 -1
View File
@@ -41,7 +41,7 @@ type model struct {
}
func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP {
if m.tunnelType == types.TunnelTypeHTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
}
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
+6 -6
View File
@@ -15,7 +15,7 @@ import (
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
if m.tunnelType != types.HTTP {
if m.tunnelType != types.TunnelTypeHTTP {
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
@@ -30,10 +30,10 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
Id: m.interaction.slug.String(),
Type: types.HTTP,
Type: types.TunnelTypeHTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.HTTP,
Type: types.TunnelTypeHTTP,
}); err != nil {
m.slugError = err.Error()
return m, nil
@@ -130,7 +130,7 @@ func (m *model) slugView() string {
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
if m.tunnelType != types.HTTP {
if m.tunnelType != types.TunnelTypeHTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")).
@@ -145,9 +145,9 @@ func (m *model) slugView() string {
var warningText string
if isVeryCompact {
warningText = "⚠️ TCP tunnels don't support custom subdomains."
warningText = "⚠️ TunnelTypeTCP tunnels don't support custom subdomains."
} else {
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
warningText = "⚠️ TunnelTypeTCP tunnels cannot have custom subdomains. Only TunnelTypeHTTP/HTTPS tunnels support subdomain customization."
}
b.WriteString(warningBoxStyle.Render(warningText))
b.WriteString("\n\n")
+7 -7
View File
@@ -24,7 +24,7 @@ type SessionRegistry interface {
}
type lifecycle struct {
status types.Status
status types.SessionStatus
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
@@ -37,7 +37,7 @@ type lifecycle struct {
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
return &lifecycle{
status: types.INITIALIZING,
status: types.SessionStatusINITIALIZING,
conn: conn,
channel: nil,
forwarder: forwarder,
@@ -54,7 +54,7 @@ type Lifecycle interface {
PortRegistry() portUtil.Port
User() string
SetChannel(channel ssh.Channel)
SetStatus(status types.Status)
SetStatus(status types.SessionStatus)
IsActive() bool
StartedAt() time.Time
Close() error
@@ -74,9 +74,9 @@ func (l *lifecycle) SetChannel(channel ssh.Channel) {
func (l *lifecycle) Connection() ssh.Conn {
return l.conn
}
func (l *lifecycle) SetStatus(status types.Status) {
func (l *lifecycle) SetStatus(status types.SessionStatus) {
l.status = status
if status == types.RUNNING && l.startedAt.IsZero() {
if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now()
}
}
@@ -112,7 +112,7 @@ func (l *lifecycle) Close() error {
}
l.sessionRegistry.Remove(key)
if tunnelType == types.TCP {
if tunnelType == types.TunnelTypeTCP {
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
firstErr = err
}
@@ -122,7 +122,7 @@ func (l *lifecycle) Close() error {
}
func (l *lifecycle) IsActive() bool {
return l.status == types.RUNNING
return l.status == types.SessionStatusRUNNING
}
func (l *lifecycle) StartedAt() time.Time {
+19 -17
View File
@@ -37,6 +37,7 @@ type Session interface {
}
type session struct {
config config.Config
initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel
lifecycle lifecycle.Lifecycle
@@ -48,13 +49,14 @@ type session struct {
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
slugManager := slug.New()
forwarderManager := forwarder.New(slugManager, conn)
forwarderManager := forwarder.New(config, slugManager, conn)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
return &session{
config: config,
initialReq: initialReq,
sshChan: sshChan,
lifecycle: lifecycleManager,
@@ -83,12 +85,12 @@ func (s *session) Slug() slug.Slug {
func (s *session) Detail() *types.Detail {
tunnelTypeMap := map[types.TunnelType]string{
types.HTTP: "HTTP",
types.TCP: "TCP",
types.TunnelTypeHTTP: "TunnelTypeHTTP",
types.TunnelTypeTCP: "TunnelTypeTCP",
}
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
if !ok {
tunnelType = "UNKNOWN"
tunnelType = "TunnelTypeUNKNOWN"
}
return &types.Detail{
@@ -131,7 +133,7 @@ func (s *session) setupSessionMode() error {
}
return s.setupInteractiveMode(channel)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.HEADLESS)
s.interaction.SetMode(types.InteractiveModeHEADLESS)
return nil
}
}
@@ -152,13 +154,13 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.INTERACTIVE)
s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
return nil
}
func (s *session) handleMissingForwardRequest() error {
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain, s.config.SSHPort))
if err != nil {
return err
}
@@ -169,8 +171,8 @@ func (s *session) handleMissingForwardRequest() error {
}
func (s *session) shouldRejectUnauthorized() bool {
return s.interaction.Mode() == types.HEADLESS &&
config.Getenv("MODE", "standalone") == "standalone" &&
return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
s.config.Mode() == types.ServerModeSTANDALONE &&
s.lifecycle.User() == "UNAUTHORIZED"
}
@@ -318,7 +320,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
s.forwarder.SetType(tunnelType)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(slug)
s.lifecycle.SetStatus(types.RUNNING)
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
if listener != nil {
s.forwarder.SetListener(listener)
@@ -348,12 +350,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
}
key := types.SessionKey{Id: randomString, Type: types.HTTP}
key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
}
err = s.finalizeForwarding(req, portToBind, nil, types.HTTP, key.Id)
err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
if err != nil {
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}
@@ -371,12 +373,12 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TCP client with id: %s", key.Id))
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
}
err = s.finalizeForwarding(req, portToBind, listener, types.TCP, key.Id)
err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
if err != nil {
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}