revert-54069ad305 #11
@@ -0,0 +1,361 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetenv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
val string
|
||||
def string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "returns existing env",
|
||||
key: "TEST_ENV_EXIST",
|
||||
val: "value",
|
||||
def: "default",
|
||||
expected: "value",
|
||||
},
|
||||
{
|
||||
name: "returns default when env missing",
|
||||
key: "TEST_ENV_MISSING",
|
||||
val: "",
|
||||
def: "default",
|
||||
expected: "default",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv(tt.key, tt.val)
|
||||
} else {
|
||||
os.Unsetenv(tt.key)
|
||||
}
|
||||
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetenvBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
val string
|
||||
def bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "returns true when env is true",
|
||||
key: "TEST_BOOL_TRUE",
|
||||
val: "true",
|
||||
def: false,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "returns false when env is false",
|
||||
key: "TEST_BOOL_FALSE",
|
||||
val: "false",
|
||||
def: true,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "returns default when env missing",
|
||||
key: "TEST_BOOL_MISSING",
|
||||
val: "",
|
||||
def: true,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "returns false when env is not true",
|
||||
key: "TEST_BOOL_INVALID",
|
||||
val: "yes",
|
||||
def: true,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv(tt.key, tt.val)
|
||||
} else {
|
||||
os.Unsetenv(tt.key)
|
||||
}
|
||||
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
expect types.ServerMode
|
||||
expectErr bool
|
||||
}{
|
||||
{"standalone", "standalone", types.ServerModeSTANDALONE, false},
|
||||
{"node", "node", types.ServerModeNODE, false},
|
||||
{"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false},
|
||||
{"invalid", "invalid", 0, true},
|
||||
{"empty (default)", "", types.ServerModeSTANDALONE, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.mode != "" {
|
||||
t.Setenv("MODE", tt.mode)
|
||||
} else {
|
||||
os.Unsetenv("MODE")
|
||||
}
|
||||
mode, err := parseMode()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expect, mode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAllowedPorts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
start uint16
|
||||
end uint16
|
||||
expectErr bool
|
||||
}{
|
||||
{"valid range", "1000-2000", 1000, 2000, false},
|
||||
{"empty", "", 0, 0, false},
|
||||
{"invalid format - no dash", "1000", 0, 0, true},
|
||||
{"invalid format - too many dashes", "1000-2000-3000", 0, 0, true},
|
||||
{"invalid start port", "abc-2000", 0, 0, true},
|
||||
{"invalid end port", "1000-abc", 0, 0, true},
|
||||
{"out of range start", "70000-80000", 0, 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv("ALLOWED_PORTS", tt.val)
|
||||
} else {
|
||||
os.Unsetenv("ALLOWED_PORTS")
|
||||
}
|
||||
start, end, err := parseAllowedPorts()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.start, start)
|
||||
assert.Equal(t, tt.end, end)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBufferSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
expect int
|
||||
}{
|
||||
{"valid size", "8192", 8192},
|
||||
{"default size", "", 32768},
|
||||
{"too small", "1024", 4096},
|
||||
{"too large", "2000000", 4096},
|
||||
{"invalid format", "abc", 4096},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv("BUFFER_SIZE", tt.val)
|
||||
} else {
|
||||
os.Unsetenv("BUFFER_SIZE")
|
||||
}
|
||||
size := parseBufferSize()
|
||||
assert.Equal(t, tt.expect, size)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envs map[string]string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "minimal valid config",
|
||||
envs: map[string]string{
|
||||
"DOMAIN": "example.com",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "TLS enabled without token",
|
||||
envs: map[string]string{
|
||||
"TLS_ENABLED": "true",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "TLS enabled with token",
|
||||
envs: map[string]string{
|
||||
"TLS_ENABLED": "true",
|
||||
"CF_API_TOKEN": "secret",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Node mode without token",
|
||||
envs: map[string]string{
|
||||
"MODE": "node",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Node mode with token",
|
||||
envs: map[string]string{
|
||||
"MODE": "node",
|
||||
"NODE_TOKEN": "token",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid mode",
|
||||
envs: map[string]string{
|
||||
"MODE": "invalid",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid allowed ports",
|
||||
envs: map[string]string{
|
||||
"ALLOWED_PORTS": "1000",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
os.Clearenv()
|
||||
for k, v := range tt.envs {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
cfg, err := parse()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cfg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetters(t *testing.T) {
|
||||
envs := map[string]string{
|
||||
"DOMAIN": "example.com",
|
||||
"PORT": "2222",
|
||||
"HTTP_PORT": "80",
|
||||
"HTTPS_PORT": "443",
|
||||
"TLS_ENABLED": "true",
|
||||
"TLS_REDIRECT": "true",
|
||||
"ACME_EMAIL": "test@example.com",
|
||||
"CF_API_TOKEN": "token",
|
||||
"ACME_STAGING": "true",
|
||||
"ALLOWED_PORTS": "1000-2000",
|
||||
"BUFFER_SIZE": "16384",
|
||||
"PPROF_ENABLED": "true",
|
||||
"PPROF_PORT": "7070",
|
||||
"MODE": "standalone",
|
||||
"GRPC_ADDRESS": "127.0.0.1",
|
||||
"GRPC_PORT": "9090",
|
||||
"NODE_TOKEN": "ntoken",
|
||||
}
|
||||
|
||||
os.Clearenv()
|
||||
for k, v := range envs {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
cfg, err := parse()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "example.com", cfg.Domain())
|
||||
assert.Equal(t, "2222", cfg.SSHPort())
|
||||
assert.Equal(t, "80", cfg.HTTPPort())
|
||||
assert.Equal(t, "443", cfg.HTTPSPort())
|
||||
assert.Equal(t, true, cfg.TLSEnabled())
|
||||
assert.Equal(t, true, cfg.TLSRedirect())
|
||||
assert.Equal(t, "test@example.com", cfg.ACMEEmail())
|
||||
assert.Equal(t, "token", cfg.CFAPIToken())
|
||||
assert.Equal(t, true, cfg.ACMEStaging())
|
||||
assert.Equal(t, uint16(1000), cfg.AllowedPortsStart())
|
||||
assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd())
|
||||
assert.Equal(t, 16384, cfg.BufferSize())
|
||||
assert.Equal(t, true, cfg.PprofEnabled())
|
||||
assert.Equal(t, "7070", cfg.PprofPort())
|
||||
assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode())
|
||||
assert.Equal(t, "127.0.0.1", cfg.GRPCAddress())
|
||||
assert.Equal(t, "9090", cfg.GRPCPort())
|
||||
assert.Equal(t, "ntoken", cfg.NodeToken())
|
||||
}
|
||||
|
||||
func TestMustLoad(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
os.Clearenv()
|
||||
t.Setenv("DOMAIN", "example.com")
|
||||
cfg, err := MustLoad()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg)
|
||||
})
|
||||
|
||||
t.Run("loadEnvFile error", func(t *testing.T) {
|
||||
err := os.Mkdir(".env", 0755)
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(".env")
|
||||
|
||||
cfg, err := MustLoad()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cfg)
|
||||
})
|
||||
|
||||
t.Run("parse error", func(t *testing.T) {
|
||||
os.Clearenv()
|
||||
t.Setenv("MODE", "invalid")
|
||||
cfg, err := MustLoad()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cfg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadEnvFile(t *testing.T) {
|
||||
t.Run("file exists", func(t *testing.T) {
|
||||
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(".env")
|
||||
|
||||
err = loadEnvFile()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE"))
|
||||
})
|
||||
|
||||
t.Run("file missing", func(t *testing.T) {
|
||||
_ = os.Remove(".env")
|
||||
err := loadEnvFile()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user