test(transport): add unit tests for transport behavior using Testify
SonarQube Scan / SonarQube Trigger (push) Successful in 1m51s

This commit is contained in:
2026-01-22 19:22:35 +07:00
parent 9d03f5507f
commit b0249c45ae
13 changed files with 1346 additions and 26 deletions
+178
View File
@@ -0,0 +1,178 @@
package transport
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"testing"
"time"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockConfig struct {
mock.Mock
}
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func TestValidateCertDomains_NotFound(t *testing.T) {
result := ValidateCertDomains("nonexistent.pem", "example.com")
assert.False(t, result)
}
func TestValidateCertDomains_InvalidPEM(t *testing.T) {
tmpFile, err := os.CreateTemp("", "invalid*.pem")
assert.NoError(t, err)
defer os.Remove(tmpFile.Name())
_, _ = tmpFile.WriteString("not a pem")
tmpFile.Close()
result := ValidateCertDomains(tmpFile.Name(), "example.com")
assert.False(t, result)
}
func TestTLSManager_getTLSConfig(t *testing.T) {
tm := &tlsManager{
useCertMagic: false,
}
cfg := tm.getTLSConfig()
assert.NotNil(t, cfg)
assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
}
func TestTLSManager_getCertificate_Magic(t *testing.T) {
tm := &tlsManager{
useCertMagic: true,
}
hello := &tls.ClientHelloInfo{}
assert.Panics(t, func() {
_, _ = tm.getCertificate(hello)
})
}
func TestTLSManager_userCertsExistAndValid(t *testing.T) {
tm := &tlsManager{
certPath: "nonexistent.pem",
keyPath: "nonexistent.key",
}
assert.False(t, tm.userCertsExistAndValid())
keyFile, _ := os.CreateTemp("", "key*.pem")
defer os.Remove(keyFile.Name())
tm.keyPath = keyFile.Name()
assert.False(t, tm.userCertsExistAndValid())
}
func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, soon bool) (string, string) {
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
notAfter := time.Now().Add(365 * 24 * time.Hour)
if expired {
notAfter = time.Now().Add(-24 * time.Hour)
} else if soon {
notAfter = time.Now().Add(15 * 24 * time.Hour)
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: domain,
},
NotBefore: time.Now().Add(-24 * time.Hour),
NotAfter: notAfter,
DNSNames: []string{domain},
}
if wildcard {
template.DNSNames = append(template.DNSNames, "*."+domain)
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
assert.NoError(t, err)
certOut, _ := os.CreateTemp("", "cert*.pem")
_ = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
certOut.Close()
keyOut, _ := os.CreateTemp("", "key*.pem")
_ = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
keyOut.Close()
return certOut.Name(), keyOut.Name()
}
func TestValidateCertDomains_Success(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
result := ValidateCertDomains(certPath, "example.com")
assert.True(t, result)
}
func TestValidateCertDomains_Expired(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, true, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
result := ValidateCertDomains(certPath, "example.com")
assert.False(t, result)
}
func TestValidateCertDomains_ExpiringSoon(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, true)
defer os.Remove(certPath)
defer os.Remove(keyPath)
result := ValidateCertDomains(certPath, "example.com")
assert.False(t, result)
}
func TestValidateCertDomains_MissingWildcard(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
result := ValidateCertDomains(certPath, "example.com")
assert.False(t, result)
}
func TestTLSManager_loadUserCerts_Success(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
tm := &tlsManager{
certPath: certPath,
keyPath: keyPath,
}
err := tm.loadUserCerts()
assert.NoError(t, err)
assert.NotNil(t, tm.userCert)
}