406 lines
8.7 KiB
Go
406 lines
8.7 KiB
Go
package config
|
|
|
|
import (
|
|
"os"
|
|
"testing"
|
|
"tunnel_pls/types"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestGetenv(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
key string
|
|
val string
|
|
def string
|
|
expected string
|
|
}{
|
|
{
|
|
name: "returns existing env",
|
|
key: "TEST_ENV_EXIST",
|
|
val: "value",
|
|
def: "default",
|
|
expected: "value",
|
|
},
|
|
{
|
|
name: "returns default when env missing",
|
|
key: "TEST_ENV_MISSING",
|
|
val: "",
|
|
def: "default",
|
|
expected: "default",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.val != "" {
|
|
t.Setenv(tt.key, tt.val)
|
|
} else {
|
|
err := os.Unsetenv(tt.key)
|
|
assert.NoError(t, err)
|
|
}
|
|
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetenvBool(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
key string
|
|
val string
|
|
def bool
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "returns true when env is true",
|
|
key: "TEST_BOOL_TRUE",
|
|
val: "true",
|
|
def: false,
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "returns false when env is false",
|
|
key: "TEST_BOOL_FALSE",
|
|
val: "false",
|
|
def: true,
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "returns default when env missing",
|
|
key: "TEST_BOOL_MISSING",
|
|
val: "",
|
|
def: true,
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "returns false when env is not true",
|
|
key: "TEST_BOOL_INVALID",
|
|
val: "yes",
|
|
def: true,
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.val != "" {
|
|
t.Setenv(tt.key, tt.val)
|
|
} else {
|
|
err := os.Unsetenv(tt.key)
|
|
assert.NoError(t, err)
|
|
}
|
|
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseMode(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
mode string
|
|
expect types.ServerMode
|
|
expectErr bool
|
|
}{
|
|
{"standalone", "standalone", types.ServerModeSTANDALONE, false},
|
|
{"node", "node", types.ServerModeNODE, false},
|
|
{"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false},
|
|
{"invalid", "invalid", 0, true},
|
|
{"empty (default)", "", types.ServerModeSTANDALONE, false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.mode != "" {
|
|
t.Setenv("MODE", tt.mode)
|
|
} else {
|
|
err := os.Unsetenv("MODE")
|
|
assert.NoError(t, err)
|
|
}
|
|
mode, err := parseMode()
|
|
if tt.expectErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expect, mode)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseAllowedPorts(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
val string
|
|
start uint16
|
|
end uint16
|
|
expectErr bool
|
|
}{
|
|
{"valid range", "1000-2000", 1000, 2000, false},
|
|
{"empty", "", 0, 0, false},
|
|
{"invalid format - no dash", "1000", 0, 0, true},
|
|
{"invalid format - too many dashes", "1000-2000-3000", 0, 0, true},
|
|
{"invalid start port", "abc-2000", 0, 0, true},
|
|
{"invalid end port", "1000-abc", 0, 0, true},
|
|
{"out of range start", "70000-80000", 0, 0, true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.val != "" {
|
|
t.Setenv("ALLOWED_PORTS", tt.val)
|
|
} else {
|
|
err := os.Unsetenv("ALLOWED_PORTS")
|
|
assert.NoError(t, err)
|
|
}
|
|
start, end, err := parseAllowedPorts()
|
|
if tt.expectErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.start, start)
|
|
assert.Equal(t, tt.end, end)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseBufferSize(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
val string
|
|
expect int
|
|
}{
|
|
{"valid size", "8192", 8192},
|
|
{"default size", "", 32768},
|
|
{"too small", "1024", 4096},
|
|
{"too large", "2000000", 4096},
|
|
{"invalid format", "abc", 4096},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.val != "" {
|
|
t.Setenv("BUFFER_SIZE", tt.val)
|
|
} else {
|
|
err := os.Unsetenv("BUFFER_SIZE")
|
|
assert.NoError(t, err)
|
|
}
|
|
size := parseBufferSize()
|
|
assert.Equal(t, tt.expect, size)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseHeaderSize(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
val string
|
|
expect int
|
|
}{
|
|
{"valid size", "8192", 8192},
|
|
{"default size", "", 4096},
|
|
{"too small", "1024", 4096},
|
|
{"too large", "2000000", 4096},
|
|
{"invalid format", "abc", 4096},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if tt.val != "" {
|
|
t.Setenv("MAX_HEADER_SIZE", tt.val)
|
|
} else {
|
|
err := os.Unsetenv("MAX_HEADER_SIZE")
|
|
assert.NoError(t, err)
|
|
}
|
|
size := parseHeaderSize()
|
|
assert.Equal(t, tt.expect, size)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParse(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
envs map[string]string
|
|
expectErr bool
|
|
}{
|
|
{
|
|
name: "minimal valid config",
|
|
envs: map[string]string{
|
|
"DOMAIN": "example.com",
|
|
},
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "TLS enabled without token",
|
|
envs: map[string]string{
|
|
"TLS_ENABLED": "true",
|
|
},
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "TLS enabled with token",
|
|
envs: map[string]string{
|
|
"TLS_ENABLED": "true",
|
|
"CF_API_TOKEN": "secret",
|
|
},
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "Node mode without token",
|
|
envs: map[string]string{
|
|
"MODE": "node",
|
|
},
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "Node mode with token",
|
|
envs: map[string]string{
|
|
"MODE": "node",
|
|
"NODE_TOKEN": "token",
|
|
},
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "invalid mode",
|
|
envs: map[string]string{
|
|
"MODE": "invalid",
|
|
},
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "invalid allowed ports",
|
|
envs: map[string]string{
|
|
"ALLOWED_PORTS": "1000",
|
|
},
|
|
expectErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
os.Clearenv()
|
|
for k, v := range tt.envs {
|
|
t.Setenv(k, v)
|
|
}
|
|
cfg, err := parse()
|
|
if tt.expectErr {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, cfg)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, cfg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetters(t *testing.T) {
|
|
envs := map[string]string{
|
|
"DOMAIN": "example.com",
|
|
"PORT": "2222",
|
|
"HTTP_PORT": "80",
|
|
"HTTPS_PORT": "443",
|
|
"KEY_LOC": "certs/ssh/id_rsa",
|
|
"TLS_ENABLED": "true",
|
|
"TLS_REDIRECT": "true",
|
|
"TLS_STORAGE_PATH": "certs/tls/",
|
|
"ACME_EMAIL": "test@example.com",
|
|
"CF_API_TOKEN": "token",
|
|
"ACME_STAGING": "true",
|
|
"ALLOWED_PORTS": "1000-2000",
|
|
"BUFFER_SIZE": "16384",
|
|
"MAX_HEADER_SIZE": "4096",
|
|
"PPROF_ENABLED": "true",
|
|
"PPROF_PORT": "7070",
|
|
"MODE": "standalone",
|
|
"GRPC_ADDRESS": "127.0.0.1",
|
|
"GRPC_PORT": "9090",
|
|
"NODE_TOKEN": "ntoken",
|
|
}
|
|
|
|
os.Clearenv()
|
|
for k, v := range envs {
|
|
t.Setenv(k, v)
|
|
}
|
|
|
|
cfg, err := parse()
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, "example.com", cfg.Domain())
|
|
assert.Equal(t, "2222", cfg.SSHPort())
|
|
assert.Equal(t, "80", cfg.HTTPPort())
|
|
assert.Equal(t, "443", cfg.HTTPSPort())
|
|
assert.Equal(t, "certs/ssh/id_rsa", cfg.KeyLoc())
|
|
assert.Equal(t, true, cfg.TLSEnabled())
|
|
assert.Equal(t, true, cfg.TLSRedirect())
|
|
assert.Equal(t, "certs/tls/", cfg.TLSStoragePath())
|
|
assert.Equal(t, "test@example.com", cfg.ACMEEmail())
|
|
assert.Equal(t, "token", cfg.CFAPIToken())
|
|
assert.Equal(t, true, cfg.ACMEStaging())
|
|
assert.Equal(t, uint16(1000), cfg.AllowedPortsStart())
|
|
assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd())
|
|
assert.Equal(t, 16384, cfg.BufferSize())
|
|
assert.Equal(t, 4096, cfg.HeaderSize())
|
|
assert.Equal(t, true, cfg.PprofEnabled())
|
|
assert.Equal(t, "7070", cfg.PprofPort())
|
|
assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode())
|
|
assert.Equal(t, "127.0.0.1", cfg.GRPCAddress())
|
|
assert.Equal(t, "9090", cfg.GRPCPort())
|
|
assert.Equal(t, "ntoken", cfg.NodeToken())
|
|
}
|
|
|
|
func TestMustLoad(t *testing.T) {
|
|
t.Run("success", func(t *testing.T) {
|
|
os.Clearenv()
|
|
t.Setenv("DOMAIN", "example.com")
|
|
cfg, err := MustLoad()
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, cfg)
|
|
})
|
|
|
|
t.Run("loadEnvFile error", func(t *testing.T) {
|
|
err := os.Mkdir(".env", 0755)
|
|
assert.NoError(t, err)
|
|
defer func() {
|
|
err = os.Remove(".env")
|
|
assert.NoError(t, err)
|
|
}()
|
|
|
|
cfg, err := MustLoad()
|
|
assert.Error(t, err)
|
|
assert.Nil(t, cfg)
|
|
})
|
|
|
|
t.Run("parse error", func(t *testing.T) {
|
|
os.Clearenv()
|
|
t.Setenv("MODE", "invalid")
|
|
cfg, err := MustLoad()
|
|
assert.Error(t, err)
|
|
assert.Nil(t, cfg)
|
|
})
|
|
}
|
|
|
|
func TestLoadEnvFile(t *testing.T) {
|
|
t.Run("file exists", func(t *testing.T) {
|
|
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
|
|
assert.NoError(t, err)
|
|
defer func() {
|
|
err = os.Remove(".env")
|
|
assert.NoError(t, err)
|
|
}()
|
|
|
|
err = loadEnvFile()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE"))
|
|
})
|
|
|
|
t.Run("file missing", func(t *testing.T) {
|
|
_ = os.Remove(".env")
|
|
err := loadEnvFile()
|
|
assert.NoError(t, err)
|
|
})
|
|
}
|