From d2e508c8ef88a131f22005f7dda658172b5623b2 Mon Sep 17 00:00:00 2001 From: bagas Date: Fri, 23 Jan 2026 14:17:18 +0700 Subject: [PATCH] test(key): add unit tests for key behavior --- internal/key/key.go | 25 ++++- internal/key/key_test.go | 235 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+), 6 deletions(-) create mode 100644 internal/key/key_test.go diff --git a/internal/key/key.go b/internal/key/key.go index 659abe3..682a244 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -5,6 +5,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "io" "log" "os" "path/filepath" @@ -12,6 +13,18 @@ import ( "golang.org/x/crypto/ssh" ) +var ( + rsaGenerateKey = rsa.GenerateKey + pemEncode = pem.Encode + sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) { + return ssh.NewPublicKey(key) + } + pubKeyWrite = func(w io.Writer, data []byte) (int, error) { + return w.Write(data) + } + osOpenFile = os.OpenFile +) + func GenerateSSHKeyIfNotExist(keyPath string) error { if _, err := os.Stat(keyPath); err == nil { log.Printf("SSH key already exists at %s", keyPath) @@ -20,7 +33,7 @@ func GenerateSSHKeyIfNotExist(keyPath string) error { log.Printf("SSH key not found at %s, generating new key pair...", keyPath) - privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + privateKey, err := rsaGenerateKey(rand.Reader, 4096) if err != nil { return err } @@ -35,29 +48,29 @@ func GenerateSSHKeyIfNotExist(keyPath string) error { return err } - privateKeyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + privateKeyFile, err := osOpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err } defer privateKeyFile.Close() - if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil { + if err := pemEncode(privateKeyFile, privateKeyPEM); err != nil { return err } - publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + publicKey, err := sshNewPublicKey(&privateKey.PublicKey) if err != nil { return err } pubKeyPath := keyPath + ".pub" - pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + pubKeyFile, err := osOpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { return err } defer pubKeyFile.Close() - _, err = pubKeyFile.Write(ssh.MarshalAuthorizedKey(publicKey)) + _, err = pubKeyWrite(pubKeyFile, ssh.MarshalAuthorizedKey(publicKey)) if err != nil { return err } diff --git a/internal/key/key_test.go b/internal/key/key_test.go new file mode 100644 index 0000000..d28c33b --- /dev/null +++ b/internal/key/key_test.go @@ -0,0 +1,235 @@ +package key + +import ( + "crypto/rsa" + "encoding/pem" + "errors" + "io" + "os" + "path/filepath" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestGenerateSSHKeyIfNotExist(t *testing.T) { + tempDir := t.TempDir() + + tests := []struct { + name string + setup func(t *testing.T, tempDir string) string + mockSetup func() func() + wantErr bool + errStr string + verify func(t *testing.T, keyPath string) + }{ + { + name: "GenerateNewKey", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "id_rsa") + }, + verify: func(t *testing.T, keyPath string) { + pubKeyPath := keyPath + ".pub" + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + t.Errorf("Private key file not created") + } + if _, err := os.Stat(pubKeyPath); os.IsNotExist(err) { + t.Errorf("Public key file not created") + } + privateKeyBytes, err := os.ReadFile(keyPath) + if err != nil { + t.Fatalf("Failed to read private key: %v", err) + } + if _, err = ssh.ParseRawPrivateKey(privateKeyBytes); err != nil { + t.Errorf("Failed to parse private key: %v", err) + } + publicKeyBytes, err := os.ReadFile(pubKeyPath) + if err != nil { + t.Fatalf("Failed to read public key: %v", err) + } + if _, _, _, _, err = ssh.ParseAuthorizedKey(publicKeyBytes); err != nil { + t.Errorf("Failed to parse public key: %v", err) + } + }, + }, + { + name: "DoNotOverwriteExistingKey", + setup: func(t *testing.T, tempDir string) string { + keyPath := filepath.Join(tempDir, "existing_id_rsa") + dummyPrivate := "dummy private" + dummyPublic := "dummy public" + if err := os.WriteFile(keyPath, []byte(dummyPrivate), 0600); err != nil { + t.Fatalf("Failed to create dummy private key: %v", err) + } + if err := os.WriteFile(keyPath+".pub", []byte(dummyPublic), 0644); err != nil { + t.Fatalf("Failed to create dummy public key: %v", err) + } + return keyPath + }, + verify: func(t *testing.T, keyPath string) { + gotPrivate, _ := os.ReadFile(keyPath) + if string(gotPrivate) != "dummy private" { + t.Errorf("Private key was overwritten") + } + gotPublic, _ := os.ReadFile(keyPath + ".pub") + if string(gotPublic) != "dummy public" { + t.Errorf("Public key was overwritten") + } + }, + }, + { + name: "CreateNestedDirectories", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "nested", "dir", "id_rsa") + }, + verify: func(t *testing.T, keyPath string) { + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + t.Errorf("Private key file not created in nested directory") + } + }, + }, + { + name: "FailureMkdirAll", + setup: func(t *testing.T, tempDir string) string { + dirPath := filepath.Join(tempDir, "file_as_dir") + if err := os.WriteFile(dirPath, []byte("not a dir"), 0644); err != nil { + t.Fatalf("Failed to create file: %v", err) + } + return filepath.Join(dirPath, "id_rsa") + }, + wantErr: true, + }, + { + name: "PrivateExistsPublicMissing", + setup: func(t *testing.T, tempDir string) string { + keyPath := filepath.Join(tempDir, "partial_id_rsa") + if err := os.WriteFile(keyPath, []byte("private"), 0600); err != nil { + t.Fatalf("Failed to create private key: %v", err) + } + return keyPath + }, + verify: func(t *testing.T, keyPath string) { + if _, err := os.Stat(keyPath + ".pub"); !os.IsNotExist(err) { + t.Errorf("Public key should NOT have been created if private key existed") + } + }, + }, + { + name: "FailureRSAGenerateKey", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_rsa") + }, + mockSetup: func() func() { + old := rsaGenerateKey + rsaGenerateKey = func(random io.Reader, bits int) (*rsa.PrivateKey, error) { + return nil, errors.New("rsa error") + } + return func() { rsaGenerateKey = old } + }, + wantErr: true, + errStr: "rsa error", + }, + { + name: "FailureOpenFilePrivate", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_open_private") + }, + mockSetup: func() func() { + old := osOpenFile + osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + return nil, errors.New("open error") + } + return func() { osOpenFile = old } + }, + wantErr: true, + errStr: "open error", + }, + { + name: "FailurePemEncode", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_pem") + }, + mockSetup: func() func() { + old := pemEncode + pemEncode = func(out io.Writer, b *pem.Block) error { + return errors.New("pem error") + } + return func() { pemEncode = old } + }, + wantErr: true, + errStr: "pem error", + }, + { + name: "FailureSSHNewPublicKey", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_ssh") + }, + mockSetup: func() func() { + old := sshNewPublicKey + sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) { + return nil, errors.New("ssh error") + } + return func() { sshNewPublicKey = old } + }, + wantErr: true, + errStr: "ssh error", + }, + { + name: "FailureOpenFilePublic", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_open_public") + }, + mockSetup: func() func() { + old := osOpenFile + osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + if filepath.Ext(name) == ".pub" { + return nil, errors.New("open pub error") + } + return os.OpenFile(name, flag, perm) + } + return func() { osOpenFile = old } + }, + wantErr: true, + errStr: "open pub error", + }, + { + name: "FailurePubKeyWrite", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_write") + }, + mockSetup: func() func() { + old := pubKeyWrite + pubKeyWrite = func(w io.Writer, data []byte) (int, error) { + return 0, errors.New("write error") + } + return func() { pubKeyWrite = old } + }, + wantErr: true, + errStr: "write error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keyPath := tt.setup(t, tempDir) + if tt.mockSetup != nil { + cleanup := tt.mockSetup() + defer cleanup() + } + + err := GenerateSSHKeyIfNotExist(keyPath) + + if (err != nil) != tt.wantErr { + t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && tt.errStr != "" && err != nil && err.Error() != tt.errStr { + t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErrStr %v", err, tt.errStr) + } + + if tt.verify != nil { + tt.verify(t, keyPath) + } + }) + } +}