Files
tunnel-please/internal/config/config_test.go

406 lines
8.7 KiB
Go

package config
import (
"os"
"testing"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
)
// TestGetenv checks that getenv returns the variable's value when it is set
// and falls back to the supplied default when it is absent.
func TestGetenv(t *testing.T) {
	tests := []struct {
		name     string
		key      string
		val      string
		def      string
		expected string
	}{
		{
			name:     "returns existing env",
			key:      "TEST_ENV_EXIST",
			val:      "value",
			def:      "default",
			expected: "value",
		},
		{
			name:     "returns default when env missing",
			key:      "TEST_ENV_MISSING",
			val:      "",
			def:      "default",
			expected: "default",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// t.Setenv registers a cleanup that restores the original value.
			// For the "missing" case the variable is then removed entirely, so
			// getenv sees it as unset rather than set to "".
			t.Setenv(tt.key, tt.val)
			if tt.val == "" {
				assert.NoError(t, os.Unsetenv(tt.key))
			}
			assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
		})
	}
}
// TestGetenvBool checks getenvBool's behavior for each kind of input:
//   - "true" yields true.
//   - "false" and any other non-"true" value (such as "yes") yield false.
//   - An absent variable yields the supplied default.
func TestGetenvBool(t *testing.T) {
	tests := []struct {
		name     string
		key      string
		val      string
		def      bool
		expected bool
	}{
		{
			name:     "returns true when env is true",
			key:      "TEST_BOOL_TRUE",
			val:      "true",
			def:      false,
			expected: true,
		},
		{
			name:     "returns false when env is false",
			key:      "TEST_BOOL_FALSE",
			val:      "false",
			def:      true,
			expected: false,
		},
		{
			name:     "returns default when env missing",
			key:      "TEST_BOOL_MISSING",
			val:      "",
			def:      true,
			expected: true,
		},
		{
			name:     "returns false when env is not true",
			key:      "TEST_BOOL_INVALID",
			val:      "yes",
			def:      true,
			expected: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// t.Setenv guarantees the original value is restored afterwards;
			// an empty val means "unset", so remove the variable entirely.
			t.Setenv(tt.key, tt.val)
			if tt.val == "" {
				assert.NoError(t, os.Unsetenv(tt.key))
			}
			assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
		})
	}
}
// TestParseMode checks how parseMode interprets the MODE variable:
//   - "standalone" and "node" map to their server modes.
//   - An upper-case name is also accepted.
//   - An unset MODE yields standalone.
//   - An unknown value is an error.
func TestParseMode(t *testing.T) {
	tests := []struct {
		name      string
		mode      string
		expect    types.ServerMode
		expectErr bool
	}{
		{"standalone", "standalone", types.ServerModeSTANDALONE, false},
		{"node", "node", types.ServerModeNODE, false},
		{"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false},
		{"invalid", "invalid", 0, true},
		{"empty (default)", "", types.ServerModeSTANDALONE, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// t.Setenv restores the caller's MODE when the subtest ends; the
			// empty case then removes MODE so the default path is exercised.
			t.Setenv("MODE", tt.mode)
			if tt.mode == "" {
				assert.NoError(t, os.Unsetenv("MODE"))
			}
			mode, err := parseMode()
			if tt.expectErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tt.expect, mode)
		})
	}
}
// TestParseAllowedPorts checks parsing of ALLOWED_PORTS:
//   - A "start-end" pair of numeric ports is accepted.
//   - An unset variable yields 0, 0 with no error.
//   - A missing dash, extra dashes, non-numeric parts, or out-of-range numbers
//     are errors.
func TestParseAllowedPorts(t *testing.T) {
	tests := []struct {
		name      string
		val       string
		start     uint16
		end       uint16
		expectErr bool
	}{
		{"valid range", "1000-2000", 1000, 2000, false},
		{"empty", "", 0, 0, false},
		{"invalid format - no dash", "1000", 0, 0, true},
		{"invalid format - too many dashes", "1000-2000-3000", 0, 0, true},
		{"invalid start port", "abc-2000", 0, 0, true},
		{"invalid end port", "1000-abc", 0, 0, true},
		{"out of range start", "70000-80000", 0, 0, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Register restoration via t.Setenv before (optionally) unsetting,
			// so the process environment is left as it was found.
			t.Setenv("ALLOWED_PORTS", tt.val)
			if tt.val == "" {
				assert.NoError(t, os.Unsetenv("ALLOWED_PORTS"))
			}
			start, end, err := parseAllowedPorts()
			if tt.expectErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tt.start, start)
			assert.Equal(t, tt.end, end)
		})
	}
}
// TestParseBufferSize checks parseBufferSize's handling of BUFFER_SIZE:
//   - An accepted value is returned as-is.
//   - An unset variable yields 32768.
//   - Values the cases treat as rejected (too small, too large, non-numeric)
//     yield 4096.
func TestParseBufferSize(t *testing.T) {
	tests := []struct {
		name   string
		val    string
		expect int
	}{
		{"valid size", "8192", 8192},
		{"default size", "", 32768},
		{"too small", "1024", 4096},
		{"too large", "2000000", 4096},
		{"invalid format", "abc", 4096},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// t.Setenv restores BUFFER_SIZE afterwards; the empty case then
			// removes it so the unset default is exercised.
			t.Setenv("BUFFER_SIZE", tt.val)
			if tt.val == "" {
				assert.NoError(t, os.Unsetenv("BUFFER_SIZE"))
			}
			assert.Equal(t, tt.expect, parseBufferSize())
		})
	}
}
// TestParseHeaderSize checks parseHeaderSize's handling of MAX_HEADER_SIZE:
//   - An accepted value is returned as-is.
//   - An unset variable yields 4096.
//   - Values the cases treat as rejected (too small, too large, non-numeric)
//     also yield 4096.
func TestParseHeaderSize(t *testing.T) {
	tests := []struct {
		name   string
		val    string
		expect int
	}{
		{"valid size", "8192", 8192},
		{"default size", "", 4096},
		{"too small", "1024", 4096},
		{"too large", "2000000", 4096},
		{"invalid format", "abc", 4096},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// t.Setenv restores MAX_HEADER_SIZE afterwards; the empty case then
			// removes it so the unset default is exercised.
			t.Setenv("MAX_HEADER_SIZE", tt.val)
			if tt.val == "" {
				assert.NoError(t, os.Unsetenv("MAX_HEADER_SIZE"))
			}
			assert.Equal(t, tt.expect, parseHeaderSize())
		})
	}
}
// TestParse checks whether parse accepts or rejects a whole environment.
//
// Each case starts from an empty process environment that contains only the
// listed variables. The original environment is snapshotted up front and
// restored when the test finishes, so clearing it does not leak into other
// tests (PATH, HOME, etc.).
func TestParse(t *testing.T) {
	tests := []struct {
		name      string
		envs      map[string]string
		expectErr bool
	}{
		{
			name: "minimal valid config",
			envs: map[string]string{
				"DOMAIN": "example.com",
			},
			expectErr: false,
		},
		{
			name: "TLS enabled without token",
			envs: map[string]string{
				"TLS_ENABLED": "true",
			},
			expectErr: true,
		},
		{
			name: "TLS enabled with token",
			envs: map[string]string{
				"TLS_ENABLED":  "true",
				"CF_API_TOKEN": "secret",
			},
			expectErr: false,
		},
		{
			name: "Node mode without token",
			envs: map[string]string{
				"MODE": "node",
			},
			expectErr: true,
		},
		{
			name: "Node mode with token",
			envs: map[string]string{
				"MODE":       "node",
				"NODE_TOKEN": "token",
			},
			expectErr: false,
		},
		{
			name: "invalid mode",
			envs: map[string]string{
				"MODE": "invalid",
			},
			expectErr: true,
		},
		{
			name: "invalid allowed ports",
			envs: map[string]string{
				"ALLOWED_PORTS": "1000",
			},
			expectErr: true,
		},
	}

	saved := os.Environ()
	t.Cleanup(func() {
		os.Clearenv()
		for _, kv := range saved {
			// Split on the first '=' after index 0: on Windows some keys
			// (e.g. "=C:") themselves begin with '='.
			for i := 1; i < len(kv); i++ {
				if kv[i] == '=' {
					// Best-effort restore; a cleanup has no useful way to
					// recover from a single failed Setenv.
					_ = os.Setenv(kv[:i], kv[i+1:])
					break
				}
			}
		}
	})

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			os.Clearenv()
			for k, v := range tt.envs {
				t.Setenv(k, v)
			}
			cfg, err := parse()
			if tt.expectErr {
				assert.Error(t, err)
				assert.Nil(t, cfg)
				return
			}
			assert.NoError(t, err)
			assert.NotNil(t, cfg)
		})
	}
}
// TestGetters parses a fully populated environment and checks that every
// accessor on the resulting config reports the value that was supplied.
//
// The process environment is cleared for the test and restored afterwards.
func TestGetters(t *testing.T) {
	envs := map[string]string{
		"DOMAIN":           "example.com",
		"PORT":             "2222",
		"HTTP_PORT":        "80",
		"HTTPS_PORT":       "443",
		"KEY_LOC":          "certs/ssh/id_rsa",
		"TLS_ENABLED":      "true",
		"TLS_REDIRECT":     "true",
		"TLS_STORAGE_PATH": "certs/tls/",
		"ACME_EMAIL":       "test@example.com",
		"CF_API_TOKEN":     "token",
		"ACME_STAGING":     "true",
		"ALLOWED_PORTS":    "1000-2000",
		"BUFFER_SIZE":      "16384",
		"MAX_HEADER_SIZE":  "4096",
		"PPROF_ENABLED":    "true",
		"PPROF_PORT":       "7070",
		"MODE":             "standalone",
		"GRPC_ADDRESS":     "127.0.0.1",
		"GRPC_PORT":        "9090",
		"NODE_TOKEN":       "ntoken",
	}

	// Registered before any t.Setenv so it runs last (cleanups are LIFO)
	// and leaves the environment exactly as it was before the test.
	saved := os.Environ()
	t.Cleanup(func() {
		os.Clearenv()
		for _, kv := range saved {
			// Split on the first '=' after index 0: on Windows some keys
			// (e.g. "=C:") themselves begin with '='.
			for i := 1; i < len(kv); i++ {
				if kv[i] == '=' {
					// Best-effort restore; nothing useful to do on failure.
					_ = os.Setenv(kv[:i], kv[i+1:])
					break
				}
			}
		}
	})

	os.Clearenv()
	for k, v := range envs {
		t.Setenv(k, v)
	}

	cfg, err := parse()
	if !assert.NoError(t, err) || !assert.NotNil(t, cfg) {
		// Every assertion below calls a method on cfg; stop instead of
		// panicking on a nil config.
		t.FailNow()
	}
	assert.Equal(t, "example.com", cfg.Domain())
	assert.Equal(t, "2222", cfg.SSHPort())
	assert.Equal(t, "80", cfg.HTTPPort())
	assert.Equal(t, "443", cfg.HTTPSPort())
	assert.Equal(t, "certs/ssh/id_rsa", cfg.KeyLoc())
	assert.True(t, cfg.TLSEnabled())
	assert.True(t, cfg.TLSRedirect())
	assert.Equal(t, "certs/tls/", cfg.TLSStoragePath())
	assert.Equal(t, "test@example.com", cfg.ACMEEmail())
	assert.Equal(t, "token", cfg.CFAPIToken())
	assert.True(t, cfg.ACMEStaging())
	assert.Equal(t, uint16(1000), cfg.AllowedPortsStart())
	assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd())
	assert.Equal(t, 16384, cfg.BufferSize())
	assert.Equal(t, 4096, cfg.HeaderSize())
	assert.True(t, cfg.PprofEnabled())
	assert.Equal(t, "7070", cfg.PprofPort())
	assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode())
	assert.Equal(t, "127.0.0.1", cfg.GRPCAddress())
	assert.Equal(t, "9090", cfg.GRPCPort())
	assert.Equal(t, "ntoken", cfg.NodeToken())
}
// TestMustLoad covers three paths of MustLoad:
//   - A valid environment succeeds.
//   - An unreadable .env (here, a directory) yields an error.
//   - A parse failure yields an error.
//
// The test runs in a scratch directory, so it never creates or deletes a .env
// in the package source tree. It also restores the process environment it
// clears.
func TestMustLoad(t *testing.T) {
	dir := t.TempDir()
	wd, err := os.Getwd()
	if err != nil {
		t.Fatalf("getwd: %v", err)
	}
	if err = os.Chdir(dir); err != nil {
		t.Fatalf("chdir %q: %v", dir, err)
	}
	// Runs before t.TempDir's removal (cleanups are LIFO), which matters on
	// platforms that refuse to delete the current directory.
	t.Cleanup(func() {
		// Best effort: a failed chdir back cannot be handled meaningfully here.
		_ = os.Chdir(wd)
	})

	saved := os.Environ()
	t.Cleanup(func() {
		os.Clearenv()
		for _, kv := range saved {
			// Split on the first '=' after index 0: on Windows some keys
			// (e.g. "=C:") themselves begin with '='.
			for i := 1; i < len(kv); i++ {
				if kv[i] == '=' {
					// Best-effort restore; nothing useful to do on failure.
					_ = os.Setenv(kv[:i], kv[i+1:])
					break
				}
			}
		}
	})

	t.Run("success", func(t *testing.T) {
		os.Clearenv()
		t.Setenv("DOMAIN", "example.com")
		cfg, err := MustLoad()
		assert.NoError(t, err)
		assert.NotNil(t, cfg)
	})

	t.Run("loadEnvFile error", func(t *testing.T) {
		// A directory named .env cannot be read as an env file. Stop on
		// failure so the cleanup never removes something this test did not
		// create.
		if err := os.Mkdir(".env", 0755); err != nil {
			t.Fatalf("mkdir .env: %v", err)
		}
		t.Cleanup(func() {
			assert.NoError(t, os.Remove(".env"))
		})
		cfg, err := MustLoad()
		assert.Error(t, err)
		assert.Nil(t, cfg)
	})

	t.Run("parse error", func(t *testing.T) {
		os.Clearenv()
		t.Setenv("MODE", "invalid")
		cfg, err := MustLoad()
		assert.Error(t, err)
		assert.Nil(t, cfg)
	})
}
// TestLoadEnvFile checks two cases for loadEnvFile, which reads ".env" from the
// working directory:
//   - When the file exists, its variables reach the process environment.
//   - When the file is absent, loadEnvFile does not return an error.
//
// Both subtests run in a scratch directory, so an existing .env in the package
// tree is never overwritten or deleted.
func TestLoadEnvFile(t *testing.T) {
	dir := t.TempDir()
	wd, err := os.Getwd()
	if err != nil {
		t.Fatalf("getwd: %v", err)
	}
	if err = os.Chdir(dir); err != nil {
		t.Fatalf("chdir %q: %v", dir, err)
	}
	// Runs before t.TempDir's removal (cleanups are LIFO).
	t.Cleanup(func() {
		// Best effort: a failed chdir back cannot be handled meaningfully here.
		_ = os.Chdir(wd)
	})

	t.Run("file exists", func(t *testing.T) {
		// Start with TEST_ENV_FILE absent and have t.Setenv restore its
		// original state afterwards, since loadEnvFile writes it into the
		// process environment.
		t.Setenv("TEST_ENV_FILE", "")
		assert.NoError(t, os.Unsetenv("TEST_ENV_FILE"))

		if err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644); err != nil {
			t.Fatalf("write .env: %v", err)
		}
		t.Cleanup(func() {
			assert.NoError(t, os.Remove(".env"))
		})

		assert.NoError(t, loadEnvFile())
		assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE"))
	})

	t.Run("file missing", func(t *testing.T) {
		// Only touches the scratch directory; a "not exist" error is expected
		// and irrelevant.
		_ = os.Remove(".env")
		assert.NoError(t, loadEnvFile())
	})
}