From c0916ceefc95f3baa25bc521a0129030c164acb2 Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Thu, 20 Nov 2025 09:52:34 +0100 Subject: [PATCH 01/13] change: introduce new implementation of dbcrypt package --- pkg/dbcrypt/cipher.go | 191 +++++++++++++++++++++ pkg/dbcrypt/crypto.go | 308 ++++++++++++++++++++++++++++++++++ pkg/dbcrypt/dbcryptv2.go | 90 ++++++++++ pkg/dbcrypt/dbcryptv2_test.go | 224 +++++++++++++++++++++++++ pkg/dbcrypt/value.go | 80 +++++++++ 5 files changed, 893 insertions(+) create mode 100644 pkg/dbcrypt/cipher.go create mode 100644 pkg/dbcrypt/crypto.go create mode 100644 pkg/dbcrypt/dbcryptv2.go create mode 100644 pkg/dbcrypt/dbcryptv2_test.go create mode 100644 pkg/dbcrypt/value.go diff --git a/pkg/dbcrypt/cipher.go b/pkg/dbcrypt/cipher.go new file mode 100644 index 0000000..7c475f7 --- /dev/null +++ b/pkg/dbcrypt/cipher.go @@ -0,0 +1,191 @@ +// SPDX-FileCopyrightText: 2025 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package dbcrypt + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "fmt" + "io" + + "golang.org/x/crypto/argon2" +) + +type Config struct { + // Default version. Useful for testing older historical implementations or disabling encryption. Leave empty to use the most recent version. + // + // - use for v2 + // - use "ENCV2" for v2 + // - use "ENC" for v1 + // - use "PLAIN" to disable encryption + Version string + + // Contains the password used deriving encryption key + Password string + + // Contains the salt for increasing password entropy + PasswordSalt string +} + +var dbCiphers = []func(Config) (dbCipher, error){ + newDbCipherPlain, + newDbCipherV1, + newDbCipherV2, +} + +const defaultCipherPrefix = "ENCV2" + +type dbCipher interface { + Prefix() string + Encrypt(plaintext []byte) ([]byte, error) + Decrypt(ciphertext []byte) ([]byte, error) +} + +type dbCipherPlain struct { +} + +func newDbCipherPlain(Config) (dbCipher, error) { + return dbCipherPlain{}, nil +} + +func (c dbCipherPlain) Prefix() string { + return "PLAIN" +} + +func (c dbCipherPlain) Encrypt(plaintext []byte) ([]byte, error) { + return plaintext, nil +} + +func (c dbCipherPlain) Decrypt(ciphertext []byte) ([]byte, error) { + return ciphertext, nil +} + +type dbCipherV1 struct { + key []byte +} + +func newDbCipherV1(conf Config) (dbCipher, error) { + // Historically "v1" uses key truncation to 32 bytes. It needs to be preserved for backward compatibility. + key := []byte(conf.Password + conf.PasswordSalt)[:32] + return dbCipherV1{key: key}, nil +} + +func (c dbCipherV1) Prefix() string { + return "ENC" +} + +func (c dbCipherV1) Encrypt(plaintext []byte) ([]byte, error) { + block, err := aes.NewCipher(c.key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + iv := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + + ciphertext := gcm.Seal(nil, iv, []byte(plaintext), nil) + ciphertextWithIv := append(iv, ciphertext...) + encoded := hex.AppendEncode(nil, ciphertextWithIv) + return encoded, nil +} + +func (c dbCipherV1) Decrypt(encoded []byte) ([]byte, error) { + ciphertextWithIv, err := hex.AppendDecode(nil, encoded) + if err != nil { + return nil, fmt.Errorf("error decoding ciphertext: %w", err) + } + + if len(ciphertextWithIv) < aes.BlockSize+1 { + return nil, fmt.Errorf("ciphertext too short") + } + + block, err := aes.NewCipher(c.key) + if err != nil { + return nil, fmt.Errorf("error creating AES cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + iv := ciphertextWithIv[:gcm.NonceSize()] + ciphertext := ciphertextWithIv[gcm.NonceSize():] + + plaintext, err := gcm.Open(nil, iv, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("error decrypting ciphertext: %w", err) + } + + return plaintext, nil +} + +type dbCipherV2 struct { + key []byte +} + +func newDbCipherV2(conf Config) (dbCipher, error) { + // v2" uses proper KDF (argon2id) to get the key + key := argon2.IDKey([]byte(conf.Password), []byte(conf.PasswordSalt), 1, 64*1024, 4, 32) + return dbCipherV2{key: key}, nil +} + +func (c dbCipherV2) Prefix() string { + return "ENCV2" +} + +func (c dbCipherV2) Encrypt(plaintext []byte) ([]byte, error) { + block, err := aes.NewCipher(c.key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCMWithRandomNonce(block) + if err != nil { + return nil, err + } + + ciphertext := gcm.Seal(nil, nil, []byte(plaintext), nil) + encoded := base64.StdEncoding.AppendEncode(nil, ciphertext) + return encoded, nil +} + +func (c dbCipherV2) Decrypt(encoded []byte) ([]byte, error) { + ciphertext, err := base64.StdEncoding.AppendDecode(nil, encoded) + if err != nil { + return nil, fmt.Errorf("error decoding ciphertext: %w", err) + } + + if len(ciphertext) < aes.BlockSize+1 { + return nil, fmt.Errorf("ciphertext too short") + } + + block, err := aes.NewCipher(c.key) + if err != nil { + return nil, fmt.Errorf("error creating AES cipher: %w", err) + } + + gcm, err := cipher.NewGCMWithRandomNonce(block) + if err != nil { + return nil, err + } + + plaintext, err := gcm.Open(nil, nil, ciphertext, nil) + if err != nil { + return nil, fmt.Errorf("error decrypting ciphertext: %w", err) + } + + return plaintext, nil +} diff --git a/pkg/dbcrypt/crypto.go b/pkg/dbcrypt/crypto.go new file mode 100644 index 0000000..6e0547a --- /dev/null +++ b/pkg/dbcrypt/crypto.go @@ -0,0 +1,308 @@ +// SPDX-FileCopyrightText: 2025 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package dbcrypt + +import ( + "errors" + "fmt" + "reflect" + + "gorm.io/gorm" +) + +// EncryptString encrypts the given string using the provided DBCryptV2. +func EncryptString(c *DBCryptV2, plaintext string) (string, error) { + fmt.Println("--debug before encryption", plaintext) // TODO: Remove me + ciphertext, err := c.Encrypt([]byte(plaintext)) + if err != nil { + return "", err + } + fmt.Println("--debug after encryption", string(ciphertext)) // TODO: Remove me + return string(ciphertext), nil +} + +// DecryptString decrypts the given string using the provided DBCryptV2. +func DecryptString(c *DBCryptV2, ciphertextWithPrefix string) (string, error) { + fmt.Println("--debug before decryption", ciphertextWithPrefix) // TODO: Remove me + plaintext, err := c.Decrypt([]byte(ciphertextWithPrefix)) + if err != nil { + return "", err + } + fmt.Println("--debug after decryption", string(plaintext)) // TODO: Remove me + return string(plaintext), nil +} + +func asPointerToStruct(x any) reflect.Value { + v := reflect.ValueOf(x) + if v.Kind() != reflect.Pointer { + return reflect.Value{} + } + el := v.Elem() + if el.Kind() != reflect.Struct { + return reflect.Value{} + } + return v +} + +func shouldStructFieldBeEncrypted(sf reflect.StructField) (bool, error) { + encryptTag, has := sf.Tag.Lookup("encrypt") + if !has || encryptTag == "false" { + return false, nil + } + if encryptTag != "true" { + return false, fmt.Errorf("invalid value for 'encrypt' field tag %q", encryptTag) + } + if !sf.IsExported() { + return false, errors.New("unexported field marked for encryption") + } + if sf.Type.Kind() != reflect.String { + return false, errors.New("invalid type of field marked for encryption") + } + return true, nil +} + +func validateHistoricalTags(sf reflect.StructField) error { + if _, has := sf.Tag.Lookup("encrypt"); has { + return errors.New("support 'encrypt' struct filed tag has been removed with new DBCrypt package, use dbcrypt.EncryptedString type instead") + } + return nil +} + +// EncryptStruct encrypts all fields withing the given struct that are tagged with 'encrypt:"true"' using the provided DBCryptV2. +func EncryptStruct(c *DBCryptV2, plaintext any) error { + v := asPointerToStruct(plaintext) + if !v.IsValid() { + return errors.New("invalid value provided to struct encryption (expected a pointer to struct)") + } + v = v.Elem() + typ := v.Type() + for i := 0; i < typ.NumField(); i++ { + f, fTyp := v.Field(i), typ.Field(i) + doEnc, err := shouldStructFieldBeEncrypted(fTyp) + if err != nil { + return fmt.Errorf("field %q: %w", fTyp.Name, err) + } + if !doEnc { + continue + } + ciphertext, err := EncryptString(c, f.String()) + if err != nil { + return fmt.Errorf("field %q: %w", fTyp.Name, err) + } + f.SetString(ciphertext) + } + return nil +} + +func EncryptAny(c *DBCryptV2, plaintext any) error { + value := reflect.ValueOf(plaintext) + if value.Kind() == reflect.Pointer && value.Type().Elem().Kind() == reflect.Struct { + return encryptRecursive(c, value) + } + if value.Kind() == reflect.Map { + return encryptRecursive(c, value) + } + return errors.New("invalid value provided for encryption") +} + +func encryptRecursive(c *DBCryptV2, plaintext reflect.Value) error { + if es, ok := plaintext.Interface().(EncryptedString); ok { + if err := es.Encrypt(c); err != nil { + return err + } + plaintext.Set(reflect.ValueOf(es)) + return nil + } + if plaintext.Kind() == reflect.Pointer || plaintext.Kind() == reflect.Interface { + if plaintext.IsNil() { + return nil + } + return encryptRecursive(c, plaintext.Elem()) + } + if plaintext.Kind() == reflect.Struct { + typ := plaintext.Type() + for i := 0; i < typ.NumField(); i++ { + fTyp := typ.Field(i) + if !fTyp.IsExported() { + continue + } + if err := validateHistoricalTags(fTyp); err != nil { + return fmt.Errorf("field %q: %w", fTyp.Name, err) + } + if err := encryptRecursive(c, plaintext.Field(i)); err != nil { + return fmt.Errorf("field %q: %w", fTyp.Name, err) + } + } + } + if plaintext.Kind() == reflect.Map { + for k, v := range plaintext.Seq2() { + if err := encryptRecursive(c, v); err != nil { + return fmt.Errorf("map key %q: %w", k.String(), err) + } + } + } + return nil +} + +func DecryptAny(c *DBCryptV2, ciphertext any) error { + value := reflect.ValueOf(ciphertext) + if value.Kind() == reflect.Pointer && value.Type().Elem().Kind() == reflect.Struct { + return decryptRecursive(c, value) + } + if value.Kind() == reflect.Map { + return decryptRecursive(c, value) + } + return errors.New("invalid value provided for decryption") +} + +func decryptRecursive(c *DBCryptV2, ciphertext reflect.Value) error { + if es, ok := ciphertext.Interface().(EncryptedString); ok { + if err := es.decrypt(c); err != nil { + return err + } + ciphertext.Set(reflect.ValueOf(es)) + return nil + } + if ciphertext.Kind() == reflect.Pointer || ciphertext.Kind() == reflect.Interface { + if ciphertext.IsNil() { + return nil + } + return decryptRecursive(c, ciphertext.Elem()) + } + if ciphertext.Kind() == reflect.Struct { + typ := ciphertext.Type() + for i := 0; i < typ.NumField(); i++ { + fTyp := typ.Field(i) + if !fTyp.IsExported() { + continue + } + if err := validateHistoricalTags(fTyp); err != nil { + return fmt.Errorf("field %q: %w", fTyp.Name, err) + } + if err := decryptRecursive(c, ciphertext.Field(i)); err != nil { + return fmt.Errorf("field %q: %w", fTyp.Name, err) + } + } + } + if ciphertext.Kind() == reflect.Map { + for k, v := range ciphertext.Seq2() { + if err := decryptRecursive(c, v); err != nil { + return fmt.Errorf("map key %q: %w", k.String(), err) + } + } + } + return nil +} + +// DecryptStruct decrypts all fields withing the given struct that are tagged with 'encrypt:"true"' using the provided DBCryptV2. +func DecryptStruct(c *DBCryptV2, ciphertext any) error { + v := asPointerToStruct(ciphertext) + if !v.IsValid() { + return errors.New("invalid value provided to struct decryption (expected a pointer to struct)") + } + v = v.Elem() + typ := v.Type() + for i := 0; i < typ.NumField(); i++ { + f, fTyp := v.Field(i), typ.Field(i) + doEnc, err := shouldStructFieldBeEncrypted(fTyp) + if err != nil { + return fmt.Errorf("field %q: %w", fTyp.Name, err) + } + if !doEnc { + continue + } + plaintext, err := DecryptString(c, f.String()) + if err != nil { + return fmt.Errorf("field %q: %w", fTyp.Name, err) + } + f.SetString(plaintext) + } + return nil +} + +func encryptDBModel(c *DBCryptV2, plaintext any) error { + return EncryptAny(c, plaintext) +} + +func decryptDBModel(c *DBCryptV2, ciphertext any) error { + return DecryptAny(c, ciphertext) +} + +// Register registers encryption and decryption callbacks for the provided data base, to perform automatically cryptographic operations on all models using EncryptStruct and DecryptStruct functions. +func Register(db *gorm.DB, c *DBCryptV2) error { + encryptCb := func(db *gorm.DB) { + db.AddError(encryptDBModel(c, db.Statement.Dest)) + } + decryptCb := func(db *gorm.DB) { + db.AddError(decryptDBModel(c, db.Statement.Dest)) + } + + if err := db.Callback(). + Create(). + Before("gorm:create"). + Register("crypto:before_create", encryptCb); err != nil { + return err + } + if err := db.Callback(). + Create(). + After("gorm:create"). + Register("crypto:after_create", decryptCb); err != nil { + return err + } + if err := db.Callback(). + Update(). + Before("gorm:update"). + Register("crypto:before_update", encryptCb); err != nil { + return err + } + if err := db.Callback(). + Update(). + After("gorm:update"). + Register("crypto:after_update", decryptCb); err != nil { + return err + } + if err := db.Callback(). + Query(). + After("gorm:query"). + Register("crypto:after_query", decryptCb); err != nil { + return err + } + return nil +} + +// Deregister removes any encryption and decryption callbacks for the provided data base. +func Deregister(db *gorm.DB) error { + if err := db.Callback(). + Create(). + Before("gorm:create"). + Remove("crypto:before_create"); err != nil { + return err + } + if err := db.Callback(). + Create(). + After("gorm:create"). + Remove("crypto:after_create"); err != nil { + return err + } + if err := db.Callback(). + Update(). + Before("gorm:update"). + Remove("crypto:before_update"); err != nil { + return err + } + if err := db.Callback(). + Update(). + After("gorm:update"). + Remove("crypto:after_update"); err != nil { + return err + } + if err := db.Callback(). + Query(). + After("gorm:query"). + Remove("crypto:after_query"); err != nil { + return err + } + return nil +} diff --git a/pkg/dbcrypt/dbcryptv2.go b/pkg/dbcrypt/dbcryptv2.go new file mode 100644 index 0000000..b2121a5 --- /dev/null +++ b/pkg/dbcrypt/dbcryptv2.go @@ -0,0 +1,90 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package dbcrypt + +import ( + "bytes" + "errors" + "fmt" +) + +const prefixSeparator = ":" + +type DBCryptV2 struct { + encryptionCipher dbCipher + decryptionCiphers map[string]dbCipher +} + +// New creates a new instance of DBCryptV2 based on the provided Config. +func New(conf Config) (*DBCryptV2, error) { + if conf.Password == "" { + return nil, errors.New("db password is empty") + } + if conf.PasswordSalt == "" { + return nil, errors.New("db password salt is empty") + } + if len(conf.PasswordSalt) < 32 { + return nil, errors.New("db password salt is too short") + } + c := &DBCryptV2{} + if err := c.registerCiphers(conf); err != nil { + return nil, err + } + return c, nil +} + +func (c *DBCryptV2) registerCiphers(conf Config) error { + c.decryptionCiphers = map[string]dbCipher{} + for _, fn := range dbCiphers { + cipher, err := fn(conf) + if err != nil { + return err + } + c.decryptionCiphers[cipher.Prefix()] = cipher + } + + encryptionCipherPrefix := conf.Version + if encryptionCipherPrefix == "" { + encryptionCipherPrefix = defaultCipherPrefix + } + cipher := c.decryptionCiphers[encryptionCipherPrefix] + if cipher == nil { + return fmt.Errorf("invalid db cipher version %q", conf.Version) + } + c.encryptionCipher = cipher + return nil +} + +func (c *DBCryptV2) findDecryptionCipher(ciphertextWithPrefix []byte) (dbCipher, []byte, []byte) { + i := bytes.Index(ciphertextWithPrefix, []byte(prefixSeparator)) + if i < 0 { + return nil, nil, ciphertextWithPrefix + } + prefix, ciphertext := ciphertextWithPrefix[:i], ciphertextWithPrefix[i+len(prefixSeparator):] + return c.decryptionCiphers[string(prefix)], prefix, ciphertext +} + +func (c *DBCryptV2) Encrypt(plaintext []byte) ([]byte, error) { + ciphertext, err := c.encryptionCipher.Encrypt(plaintext) + if err != nil { + return nil, err + } + return append([]byte(c.encryptionCipher.Prefix()+prefixSeparator), ciphertext...), nil +} + +func (c *DBCryptV2) Decrypt(ciphertextWithPrefix []byte) ([]byte, error) { + cipher, prefix, ciphertext := c.findDecryptionCipher(ciphertextWithPrefix) + if len(prefix) == 0 { + return nil, errors.New("invalid encrypted value format") + } + if cipher == nil { + return nil, errors.New("unknown encrypted value format") + } + plaintext, err := cipher.Decrypt(ciphertext) + if err != nil { + return nil, err + } + return plaintext, nil +} diff --git a/pkg/dbcrypt/dbcryptv2_test.go b/pkg/dbcrypt/dbcryptv2_test.go new file mode 100644 index 0000000..cc26471 --- /dev/null +++ b/pkg/dbcrypt/dbcryptv2_test.go @@ -0,0 +1,224 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package dbcrypt_test + +import ( + "testing" + + "github.com/greenbone/opensight-golang-libraries/pkg/dbcrypt" + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func newTestDb[T any](t *testing.T) *gorm.DB { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + require.NoError(t, err) + var table T + err = db.AutoMigrate(&table) + require.NoError(t, err) + return db +} + +func TestGormCreateRead(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + require.Equal(t, givenData, gotData) +} + +func TestGormCreateReadNonPointerValue(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: *dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + require.Equal(t, givenData, gotData) +} + +// func TestGormCreateRead(t *testing.T) { +// type Model struct { +// ID uint `gorm:"primarykey"` +// Protected string `encrypt:"true"` +// } +// db := newTestDb[Model](t) +// crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) +// require.NoError(t, err) +// require.NoError(t, dbcrypt.Register(db, crypt)) + +// givenData := Model{Protected: "aaa"} +// require.NoError(t, db.Create(&givenData).Error) + +// gotData := Model{} +// require.NoError(t, db.First(&gotData).Error) +// require.Equal(t, givenData, gotData) +// } + +// func TestGormCreateReadRaw(t *testing.T) { +// type Model struct { +// ID uint `gorm:"primarykey"` +// Protected string `encrypt:"true"` +// } +// db := newTestDb[Model](t) +// crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) +// require.NoError(t, err) +// require.NoError(t, dbcrypt.Register(db, crypt)) + +// givenData := Model{Protected: "aaa"} +// require.NoError(t, db.Create(&givenData).Error) + +// require.NoError(t, dbcrypt.Deregister(db)) + +// gotData := Model{} +// require.NoError(t, db.First(&gotData).Error) +// require.NotEqual(t, givenData, gotData) +// } + +func TestGormCreateReadRaw(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + require.NoError(t, dbcrypt.Deregister(db)) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + require.NotEqual(t, givenData, gotData) + givenDataEncrypted, _ := givenData.Protected.Encrypted() + gotDataEncrypted, _ := gotData.Protected.Encrypted() + require.Equal(t, givenDataEncrypted, gotDataEncrypted) +} + +// func TestGormCreateUpdateRead(t *testing.T) { +// type Model struct { +// ID uint `gorm:"primarykey"` +// Protected string `encrypt:"true"` +// } +// db := newTestDb[Model](t) +// crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) +// require.NoError(t, err) +// require.NoError(t, dbcrypt.Register(db, crypt)) + +// givenData := Model{Protected: "aaa"} +// require.NoError(t, db.Create(&givenData).Error) + +// updatedData := Model{ID: givenData.ID, Protected: "bbb"} +// require.NoError(t, db.Model(&updatedData).Updates(&updatedData).Error) + +// gotData := Model{} +// require.NoError(t, db.First(&gotData).Error) +// require.Equal(t, updatedData, gotData) +// } + +func TestGormCreateUpdateRead(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + updatedData := Model{ID: givenData.ID, Protected: dbcrypt.NewEncryptedString("bbb")} + require.NoError(t, db.Updates(&updatedData).Error) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + require.Equal(t, updatedData, gotData) +} + +func TestGormCreateColumnUpdateRead(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + updatedData := Model{ID: givenData.ID, Protected: dbcrypt.NewEncryptedString("bbb")} + require.NoError(t, db.Model(&Model{}).Where("id = ?", updatedData.ID).Update("protected", dbcrypt.NewEncryptedString("bbb")).Error) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + gotData.Protected.ClearEncrypted() + require.Equal(t, updatedData, gotData) +} + +// func TestGormMixDBCryptInstances(t *testing.T) { +// type Model struct { +// ID uint `gorm:"primarykey"` +// Protected string `encrypt:"true"` +// } +// db := newTestDb[Model](t) +// cryptFirst, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) +// require.NoError(t, err) +// cryptSecond, err := dbcrypt.New(dbcrypt.Config{Password: "other-encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) +// require.NoError(t, err) +// require.NoError(t, dbcrypt.Register(db, cryptFirst)) + +// givenData := Model{Protected: "aaa"} +// require.NoError(t, db.Create(&givenData).Error) + +// require.NoError(t, dbcrypt.Deregister(db)) +// require.NoError(t, dbcrypt.Register(db, cryptSecond)) +// require.Error(t, db.First(&Model{}).Error) +// } + +func TestGormMixDBCryptInstances(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + cryptFirst, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + cryptSecond, err := dbcrypt.New(dbcrypt.Config{Password: "other-encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, cryptFirst)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + require.NoError(t, dbcrypt.Deregister(db)) + require.NoError(t, dbcrypt.Register(db, cryptSecond)) + require.Error(t, db.First(&Model{}).Error) +} diff --git a/pkg/dbcrypt/value.go b/pkg/dbcrypt/value.go new file mode 100644 index 0000000..7d46297 --- /dev/null +++ b/pkg/dbcrypt/value.go @@ -0,0 +1,80 @@ +package dbcrypt + +import ( + "database/sql/driver" + "errors" +) + +type EncryptedString struct { + encrypted string + decrypted string +} + +func NewEncryptedString(val string) *EncryptedString { + return &EncryptedString{decrypted: val} +} + +func (es *EncryptedString) Scan(v any) error { + enc, ok := v.(string) + if !ok { + return errors.New("failed to unmarshal encrypted string value") + } + es.encrypted, es.decrypted = enc, "" + return nil +} + +func (es EncryptedString) Value() (driver.Value, error) { + enc, ok := es.Encrypted() + if !ok { + return nil, errors.New("cannot store string value: encryption required") + } + if enc == "" { + return nil, nil + } + return enc, nil +} + +func (es *EncryptedString) Encrypted() (string, bool) { + if es == nil { + return "", true + } + has := es.encrypted != "" || es.decrypted == "" + return es.encrypted, has +} + +func (es *EncryptedString) Encrypt(c *DBCryptV2) error { + enc, err := EncryptString(c, es.decrypted) + if err != nil { + return err + } + es.encrypted = enc + return nil +} + +func (es *EncryptedString) ClearEncrypted() { + es.encrypted = "" +} + +func (es *EncryptedString) decrypt(c *DBCryptV2) error { + enc, ok := es.Encrypted() + if !ok || enc == "" { + return nil + } + dec, err := DecryptString(c, enc) + if err != nil { + return err + } + es.decrypted = dec + return nil +} + +func (es *EncryptedString) Get() string { + if es == nil { + return "" + } + return es.decrypted +} + +func (es *EncryptedString) Set(to string) { + es.encrypted, es.decrypted = "", to +} From aad275c908f965f86a358e33d9de84bf186c8e8b Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Fri, 21 Nov 2025 13:44:13 +0100 Subject: [PATCH 02/13] change: refine new implementation of dbcrypt package --- pkg/dbcrypt/README.md | 128 +++++-------------- pkg/dbcrypt/cipher.go | 45 +------ pkg/dbcrypt/cipher_test.go | 215 ++++++++++++++++++++++++++++++++ pkg/dbcrypt/config/config.go | 29 ----- pkg/dbcrypt/crypto.go | 180 ++++++--------------------- pkg/dbcrypt/dbcipher.go | 104 ++++++++++++++++ pkg/dbcrypt/dbcrypt.go | 142 --------------------- pkg/dbcrypt/dbcrypt_test.go | 113 ----------------- pkg/dbcrypt/dbcryptv2.go | 90 -------------- pkg/dbcrypt/dbcryptv2_test.go | 224 ---------------------------------- pkg/dbcrypt/gorm_test.go | 198 ++++++++++++++++++++++++++++++ pkg/dbcrypt/value.go | 34 ++++-- 12 files changed, 610 insertions(+), 892 deletions(-) create mode 100644 pkg/dbcrypt/cipher_test.go delete mode 100644 pkg/dbcrypt/config/config.go create mode 100644 pkg/dbcrypt/dbcipher.go delete mode 100644 pkg/dbcrypt/dbcrypt.go delete mode 100644 pkg/dbcrypt/dbcrypt_test.go delete mode 100644 pkg/dbcrypt/dbcryptv2.go delete mode 100644 pkg/dbcrypt/dbcryptv2_test.go create mode 100644 pkg/dbcrypt/gorm_test.go diff --git a/pkg/dbcrypt/README.md b/pkg/dbcrypt/README.md index 124822d..765f60f 100644 --- a/pkg/dbcrypt/README.md +++ b/pkg/dbcrypt/README.md @@ -2,10 +2,7 @@ # dbcrypt Package Documentation -This package provides functions for encrypting and decrypting fields of entities persisted with GORM -using the AES algorithm. It uses the GCM mode of operation for encryption, which provides authentication and integrity -protection for the encrypted data. -It can be used to encrypt / decrypt sensitive data using gorm hooks (see example) +This package provides functions for encrypting and decrypting fields of entities persisted with GORM using the AES algorithm. It uses the GCM mode of operation for encryption, which provides authentication and integrity protection for the encrypted data. It can be used to encrypt and decrypt sensitive data using gorm hooks. ## Example Usage @@ -15,125 +12,56 @@ Here is an example of how to use the dbcrypt package: package main import ( - "fmt" + "log" - "github.com/example/dbcrypt" + "github.com/greenbone/opensight-golang-libraries/pkg/dbcrypt" ) type Person struct { gorm.Model - Field1 string - PwdField string `encrypt:"true"` + PasswordField *dbcrypt.EncryptedString `encrypt:"true"` } -func (a *MyTable) encrypt(tx *gorm.DB) (err error) { - err = cryptor.EncryptStruct(a) +func main() { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) if err != nil { - err := tx.AddError(fmt.Errorf("unable to encrypt password %w", err)) - if err != nil { - return err - } - return err + log.Fatalf("Error %v", err) } - return nil -} - -func (a *MyTable) BeforeCreate(tx *gorm.DB) (err error) { - return a.encrypt(tx) -} -func (a *MyTable) AfterFind(tx *gorm.DB) (err error) { - err = cryptor.DecryptStruct(a) + cipher, err := dbcrypt.NewDBCipher(dbcrypt.Config{ + Password: "password", + PasswordSalt: "password-salt-0123456789-0123456", + }) if err != nil { - err := tx.AddError(fmt.Errorf("Unable to decrypt password %w", err)) - if err != nil { - return err - } - return err + log.Fatalf("Error %v", err) } - return nil -} - -``` - -In this example, a Person struct is created and encrypted using the DBCrypt struct. The encrypted struct is then saved to the database. Finally the struct is decrypted when the gorm hook is -activated. - ---- - - - - - -# dbcrypt - -```go -import "github.com/greenbone/opensight-golang-libraries/pkg/dbcrypt" -``` - -## Index - -- [func Decrypt\(encrypted string, key \[\]byte\) \(string, error\)](<#Decrypt>) -- [func Encrypt\(plaintext string, key \[\]byte\) \(string, error\)](<#Encrypt>) -- [type DBCrypt](<#DBCrypt>) - - [func \(d \*DBCrypt\[T\]\) DecryptStruct\(data \*T\) error](<#DBCrypt[T].DecryptStruct>) - - [func \(d \*DBCrypt\[T\]\) EncryptStruct\(data \*T\) error](<#DBCrypt[T].EncryptStruct>) - - - -## func Decrypt - -```go -func Decrypt(encrypted string, key []byte) (string, error) -``` - - - - -## func Encrypt - -```go -func Encrypt(plaintext string, key []byte) (string, error) -``` - - - - -## type DBCrypt - + dbcrypt.Register(db, cipher) + personWrite := &Person{PasswordField: dbcrypt.NewEncryptedString("secret")} + if err := db.Create(personWrite).Error; err != nil { + log.Fatalf("Error %v", err) + } -```go -type DBCrypt[T any] struct { - // contains filtered or unexported fields + personRead := &Person{} + if err := db.First(personRead).Error; err != nil { + log.Fatalf("Error %v", err) + } } ``` - -### func \(\*DBCrypt\[T\]\) DecryptStruct +In this example, a Person struct is created and `PasswordField` is automatically encrypted before storing in the database using the DBCipher. Then, when the data is retrieved from the database `PasswordField` is automatically decrypted. -```go -func (d *DBCrypt[T]) DecryptStruct(data *T) error -``` - -DecryptStruct decrypts all fields of a struct that are tagged with \`encrypt:"true"\` - - -### func \(\*DBCrypt\[T\]\) EncryptStruct +Alternatively while creating the model you can use a tags instead of dedicated types: ```go -func (d *DBCrypt[T]) EncryptStruct(data *T) error +type Person struct { + gorm.Model + PasswordField string `encrypt:"true"` +} ``` -EncryptStruct encrypts all fields of a struct that are tagged with \`encrypt:"true"\` - -Generated by [gomarkdoc]() - - - - # License Copyright (C) 2022-2023 [Greenbone AG][Greenbone AG] -Licensed under the [GNU General Public License v3.0 or later](../../LICENSE). \ No newline at end of file +Licensed under the [GNU General Public License v3.0 or later](../../LICENSE). diff --git a/pkg/dbcrypt/cipher.go b/pkg/dbcrypt/cipher.go index 7c475f7..bf69dff 100644 --- a/pkg/dbcrypt/cipher.go +++ b/pkg/dbcrypt/cipher.go @@ -16,55 +16,12 @@ import ( "golang.org/x/crypto/argon2" ) -type Config struct { - // Default version. Useful for testing older historical implementations or disabling encryption. Leave empty to use the most recent version. - // - // - use for v2 - // - use "ENCV2" for v2 - // - use "ENC" for v1 - // - use "PLAIN" to disable encryption - Version string - - // Contains the password used deriving encryption key - Password string - - // Contains the salt for increasing password entropy - PasswordSalt string -} - -var dbCiphers = []func(Config) (dbCipher, error){ - newDbCipherPlain, - newDbCipherV1, - newDbCipherV2, -} - -const defaultCipherPrefix = "ENCV2" - type dbCipher interface { Prefix() string Encrypt(plaintext []byte) ([]byte, error) Decrypt(ciphertext []byte) ([]byte, error) } -type dbCipherPlain struct { -} - -func newDbCipherPlain(Config) (dbCipher, error) { - return dbCipherPlain{}, nil -} - -func (c dbCipherPlain) Prefix() string { - return "PLAIN" -} - -func (c dbCipherPlain) Encrypt(plaintext []byte) ([]byte, error) { - return plaintext, nil -} - -func (c dbCipherPlain) Decrypt(ciphertext []byte) ([]byte, error) { - return ciphertext, nil -} - type dbCipherV1 struct { key []byte } @@ -137,7 +94,7 @@ type dbCipherV2 struct { } func newDbCipherV2(conf Config) (dbCipher, error) { - // v2" uses proper KDF (argon2id) to get the key + // "v2" uses proper KDF (argon2id) to get the key key := argon2.IDKey([]byte(conf.Password), []byte(conf.PasswordSalt), 1, 64*1024, 4, 32) return dbCipherV2{key: key}, nil } diff --git a/pkg/dbcrypt/cipher_test.go b/pkg/dbcrypt/cipher_test.go new file mode 100644 index 0000000..16ef9bb --- /dev/null +++ b/pkg/dbcrypt/cipher_test.go @@ -0,0 +1,215 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package dbcrypt_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/greenbone/opensight-golang-libraries/pkg/dbcrypt" +) + +func TestCipherEncryptAndDecrypt(t *testing.T) { + tests := []struct { + name string + config dbcrypt.Config + given string + }{ + { + name: "latest/random", + config: dbcrypt.Config{ + Version: "", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + given: uuid.NewString(), + }, + { + name: "v2/random", + config: dbcrypt.Config{ + Version: "v2", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + given: uuid.NewString(), + }, + { + name: "v2/empty", + config: dbcrypt.Config{ + Version: "v2", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + given: "", + }, + { + name: "v2/prefix", + config: dbcrypt.Config{ + Version: "v2", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + given: "ENCV2:", + }, + { + name: "v1/random", + config: dbcrypt.Config{ + Version: "v1", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + given: uuid.NewString(), + }, + { + name: "v1/empty", + config: dbcrypt.Config{ + Version: "v1", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + given: "", + }, + { + name: "v1/prefix", + config: dbcrypt.Config{ + Version: "v1", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + given: "ENC:", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c, err := dbcrypt.NewDBCipher(test.config) + require.NoError(t, err) + + ciphertext, err := c.Encrypt([]byte(test.given)) + require.NoError(t, err) + require.NotEqual(t, test.given, string(ciphertext)) + + got, err := c.Decrypt(ciphertext) + require.NoError(t, err) + require.Equal(t, test.given, string(got)) + }) + } +} + +func TestCipherCreationFailure(t *testing.T) { + tests := []struct { + name string + config dbcrypt.Config + errorShouldContain string + }{ + { + name: "unknown-version", + config: dbcrypt.Config{ + Version: "unknown", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + errorShouldContain: "cipher version", + }, + { + name: "empty-password", + config: dbcrypt.Config{ + Version: "", + Password: "", + PasswordSalt: "encryption-password-salt-0123456", + }, + errorShouldContain: "password is empty", + }, + { + name: "empty-salt", + config: dbcrypt.Config{ + Version: "", + Password: "encryption-password", + PasswordSalt: "", + }, + errorShouldContain: "salt is empty", + }, + { + name: "salt-too-short", + config: dbcrypt.Config{ + Version: "", + Password: "encryption-password", + PasswordSalt: "short-salt", + }, + errorShouldContain: "salt is too short", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := dbcrypt.NewDBCipher(test.config) + require.Error(t, err) + require.Contains(t, err.Error(), test.errorShouldContain) + }) + } +} + +func TestHistoricalDataDecryption(t *testing.T) { + tests := []struct { + name string + config dbcrypt.Config + encrypted string + decrypted string + }{ + { + name: "v1/simple", + config: dbcrypt.Config{ + Version: "v1", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + encrypted: "ENC:425378be21051852fbbc94dca44e1837039ab68f61b02e02bb731c3de0930b6bdb00", + decrypted: "FooBar", + }, + { + name: "v1/salt-truncation", // "v1" historically uses insecure password-salt truncation, this test checks preservation of this behavior + config: dbcrypt.Config{ + Version: "v1", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456789-0123456789", + }, + encrypted: "ENC:425378be21051852fbbc94dca44e1837039ab68f61b02e02bb731c3de0930b6bdb00", + decrypted: "FooBar", + }, + { + name: "v1/password-truncation", // "v1" historically uses insecure password-salt truncation, this test checks preservation of this behavior + config: dbcrypt.Config{ + Version: "v1", + Password: "encryption-password-0123456789-0123456789", + PasswordSalt: "0123456789-0123456789-0123456789", + }, + encrypted: "ENC:d18d84eb52946f069ee6b967c84657f9d9cf7d89940685ad348a161d2e212f16f5be", + decrypted: "FooBar", + }, + { + name: "v2/simple", + config: dbcrypt.Config{ + Version: "v2", + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }, + encrypted: "ENCV2:xDuCgHSIWYBuyONI1w9rFAXas7Z7ReaTZfYfy2VH0A1DrQ==", + decrypted: "FooBar", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c, err := dbcrypt.NewDBCipher(test.config) + require.NoError(t, err) + + got, err := c.Decrypt([]byte(test.encrypted)) + require.NoError(t, err) + require.Equal(t, test.decrypted, string(got)) + }) + } +} diff --git a/pkg/dbcrypt/config/config.go b/pkg/dbcrypt/config/config.go deleted file mode 100644 index 283da9a..0000000 --- a/pkg/dbcrypt/config/config.go +++ /dev/null @@ -1,29 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Greenbone AG -// -// SPDX-License-Identifier: AGPL-3.0-or-later - -package config - -import ( - "github.com/greenbone/opensight-golang-libraries/pkg/configReader" -) - -// CryptoConfig defines the configuration for service-wide cryptography options. -// -// Version specific options will apply to all newer encryption versions up to the variants with a new specific option, -// e.g. if the options MyKeyV1 and MyKeyV4 exist, MyKeyV1 will apply to v1, v2 and v3, while MyKeyV4 applies to v4 and -// newer. -type CryptoConfig struct { - // Contains the password for encrypting user group specific report encryptions using v1 to v2 - ReportEncryptionV1Password string `validate:"required" viperEnv:"TASK_REPORT_CRYPTO_V1_PASSWORD"` - // Contains the salt for encrypting user group specific report encryptions v1 to v2 - ReportEncryptionV1Salt string `validate:"required,gte=32" viperEnv:"TASK_REPORT_CRYPTO_V1_SALT"` -} - -func Read() (config CryptoConfig, err error) { - _, err = configReader.ReadEnvVarsIntoStruct(&config) - if err != nil { - return config, err - } - return config, nil -} diff --git a/pkg/dbcrypt/crypto.go b/pkg/dbcrypt/crypto.go index 6e0547a..79ab8f5 100644 --- a/pkg/dbcrypt/crypto.go +++ b/pkg/dbcrypt/crypto.go @@ -12,41 +12,7 @@ import ( "gorm.io/gorm" ) -// EncryptString encrypts the given string using the provided DBCryptV2. -func EncryptString(c *DBCryptV2, plaintext string) (string, error) { - fmt.Println("--debug before encryption", plaintext) // TODO: Remove me - ciphertext, err := c.Encrypt([]byte(plaintext)) - if err != nil { - return "", err - } - fmt.Println("--debug after encryption", string(ciphertext)) // TODO: Remove me - return string(ciphertext), nil -} - -// DecryptString decrypts the given string using the provided DBCryptV2. -func DecryptString(c *DBCryptV2, ciphertextWithPrefix string) (string, error) { - fmt.Println("--debug before decryption", ciphertextWithPrefix) // TODO: Remove me - plaintext, err := c.Decrypt([]byte(ciphertextWithPrefix)) - if err != nil { - return "", err - } - fmt.Println("--debug after decryption", string(plaintext)) // TODO: Remove me - return string(plaintext), nil -} - -func asPointerToStruct(x any) reflect.Value { - v := reflect.ValueOf(x) - if v.Kind() != reflect.Pointer { - return reflect.Value{} - } - el := v.Elem() - if el.Kind() != reflect.Struct { - return reflect.Value{} - } - return v -} - -func shouldStructFieldBeEncrypted(sf reflect.StructField) (bool, error) { +func parseEncryptStructFieldTag(sf reflect.StructField) (bool, error) { encryptTag, has := sf.Tag.Lookup("encrypt") if !has || encryptTag == "false" { return false, nil @@ -63,40 +29,7 @@ func shouldStructFieldBeEncrypted(sf reflect.StructField) (bool, error) { return true, nil } -func validateHistoricalTags(sf reflect.StructField) error { - if _, has := sf.Tag.Lookup("encrypt"); has { - return errors.New("support 'encrypt' struct filed tag has been removed with new DBCrypt package, use dbcrypt.EncryptedString type instead") - } - return nil -} - -// EncryptStruct encrypts all fields withing the given struct that are tagged with 'encrypt:"true"' using the provided DBCryptV2. -func EncryptStruct(c *DBCryptV2, plaintext any) error { - v := asPointerToStruct(plaintext) - if !v.IsValid() { - return errors.New("invalid value provided to struct encryption (expected a pointer to struct)") - } - v = v.Elem() - typ := v.Type() - for i := 0; i < typ.NumField(); i++ { - f, fTyp := v.Field(i), typ.Field(i) - doEnc, err := shouldStructFieldBeEncrypted(fTyp) - if err != nil { - return fmt.Errorf("field %q: %w", fTyp.Name, err) - } - if !doEnc { - continue - } - ciphertext, err := EncryptString(c, f.String()) - if err != nil { - return fmt.Errorf("field %q: %w", fTyp.Name, err) - } - f.SetString(ciphertext) - } - return nil -} - -func EncryptAny(c *DBCryptV2, plaintext any) error { +func encryptModel(c *DBCipher, plaintext any) error { value := reflect.ValueOf(plaintext) if value.Kind() == reflect.Pointer && value.Type().Elem().Kind() == reflect.Struct { return encryptRecursive(c, value) @@ -107,7 +40,7 @@ func EncryptAny(c *DBCryptV2, plaintext any) error { return errors.New("invalid value provided for encryption") } -func encryptRecursive(c *DBCryptV2, plaintext reflect.Value) error { +func encryptRecursive(c *DBCipher, plaintext reflect.Value) error { if es, ok := plaintext.Interface().(EncryptedString); ok { if err := es.Encrypt(c); err != nil { return err @@ -128,7 +61,7 @@ func encryptRecursive(c *DBCryptV2, plaintext reflect.Value) error { if !fTyp.IsExported() { continue } - if err := validateHistoricalTags(fTyp); err != nil { + if err := encryptFieldBasedOnTag(c, fTyp, plaintext.Field(i)); err != nil { return fmt.Errorf("field %q: %w", fTyp.Name, err) } if err := encryptRecursive(c, plaintext.Field(i)); err != nil { @@ -146,7 +79,23 @@ func encryptRecursive(c *DBCryptV2, plaintext reflect.Value) error { return nil } -func DecryptAny(c *DBCryptV2, ciphertext any) error { +func encryptFieldBasedOnTag(c *DBCipher, sf reflect.StructField, val reflect.Value) error { + tagValue, err := parseEncryptStructFieldTag(sf) + if err != nil { + return err + } + if !tagValue { + return nil + } + ciphertext, err := c.Encrypt([]byte(val.String())) + if err != nil { + return err + } + val.SetString(string(ciphertext)) + return nil +} + +func decryptModel(c *DBCipher, ciphertext any) error { value := reflect.ValueOf(ciphertext) if value.Kind() == reflect.Pointer && value.Type().Elem().Kind() == reflect.Struct { return decryptRecursive(c, value) @@ -157,7 +106,7 @@ func DecryptAny(c *DBCryptV2, ciphertext any) error { return errors.New("invalid value provided for decryption") } -func decryptRecursive(c *DBCryptV2, ciphertext reflect.Value) error { +func decryptRecursive(c *DBCipher, ciphertext reflect.Value) error { if es, ok := ciphertext.Interface().(EncryptedString); ok { if err := es.decrypt(c); err != nil { return err @@ -178,7 +127,7 @@ func decryptRecursive(c *DBCryptV2, ciphertext reflect.Value) error { if !fTyp.IsExported() { continue } - if err := validateHistoricalTags(fTyp); err != nil { + if err := decryptFieldBasedOnTag(c, fTyp, ciphertext.Field(i)); err != nil { return fmt.Errorf("field %q: %w", fTyp.Name, err) } if err := decryptRecursive(c, ciphertext.Field(i)); err != nil { @@ -196,47 +145,29 @@ func decryptRecursive(c *DBCryptV2, ciphertext reflect.Value) error { return nil } -// DecryptStruct decrypts all fields withing the given struct that are tagged with 'encrypt:"true"' using the provided DBCryptV2. -func DecryptStruct(c *DBCryptV2, ciphertext any) error { - v := asPointerToStruct(ciphertext) - if !v.IsValid() { - return errors.New("invalid value provided to struct decryption (expected a pointer to struct)") +func decryptFieldBasedOnTag(c *DBCipher, sf reflect.StructField, val reflect.Value) error { + tagValue, err := parseEncryptStructFieldTag(sf) + if err != nil { + return err } - v = v.Elem() - typ := v.Type() - for i := 0; i < typ.NumField(); i++ { - f, fTyp := v.Field(i), typ.Field(i) - doEnc, err := shouldStructFieldBeEncrypted(fTyp) - if err != nil { - return fmt.Errorf("field %q: %w", fTyp.Name, err) - } - if !doEnc { - continue - } - plaintext, err := DecryptString(c, f.String()) - if err != nil { - return fmt.Errorf("field %q: %w", fTyp.Name, err) - } - f.SetString(plaintext) + if !tagValue { + return nil + } + plaintext, err := c.Decrypt([]byte(val.String())) + if err != nil { + return err } + val.SetString(string(plaintext)) return nil } -func encryptDBModel(c *DBCryptV2, plaintext any) error { - return EncryptAny(c, plaintext) -} - -func decryptDBModel(c *DBCryptV2, ciphertext any) error { - return DecryptAny(c, ciphertext) -} - -// Register registers encryption and decryption callbacks for the provided data base, to perform automatically cryptographic operations on all models using EncryptStruct and DecryptStruct functions. -func Register(db *gorm.DB, c *DBCryptV2) error { +// Register registers encryption and decryption callbacks for the provided data base, to perform automatically cryptographic operations on all models that contain a value of type EncryptedString or a file tagged with 'encrypt:"true"'. +func Register(db *gorm.DB, c *DBCipher) error { encryptCb := func(db *gorm.DB) { - db.AddError(encryptDBModel(c, db.Statement.Dest)) + db.AddError(encryptModel(c, db.Statement.Dest)) //nolint:errcheck // error value returned by AddError can be safely ignored, as it is the same error as db.Error. } decryptCb := func(db *gorm.DB) { - db.AddError(decryptDBModel(c, db.Statement.Dest)) + db.AddError(decryptModel(c, db.Statement.Dest)) //nolint:errcheck // error value returned by AddError can be safely ignored, as it is the same error as db.Error. } if err := db.Callback(). @@ -271,38 +202,3 @@ func Register(db *gorm.DB, c *DBCryptV2) error { } return nil } - -// Deregister removes any encryption and decryption callbacks for the provided data base. -func Deregister(db *gorm.DB) error { - if err := db.Callback(). - Create(). - Before("gorm:create"). - Remove("crypto:before_create"); err != nil { - return err - } - if err := db.Callback(). - Create(). - After("gorm:create"). - Remove("crypto:after_create"); err != nil { - return err - } - if err := db.Callback(). - Update(). - Before("gorm:update"). - Remove("crypto:before_update"); err != nil { - return err - } - if err := db.Callback(). - Update(). - After("gorm:update"). - Remove("crypto:after_update"); err != nil { - return err - } - if err := db.Callback(). - Query(). - After("gorm:query"). - Remove("crypto:after_query"); err != nil { - return err - } - return nil -} diff --git a/pkg/dbcrypt/dbcipher.go b/pkg/dbcrypt/dbcipher.go new file mode 100644 index 0000000..9f43302 --- /dev/null +++ b/pkg/dbcrypt/dbcipher.go @@ -0,0 +1,104 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package dbcrypt + +import ( + "bytes" + "errors" + "fmt" +) + +const prefixSeparator = ":" + +// Config encapsulates configuration for DBCipher. +type Config struct { + // Default version of the cryptographic algorithm. Useful for testing older historical implementations. Leave empty to use the most recent version. + // + // - use for v2 version of the cryptographic algorithm + // - use "v2" for v2 version of the cryptographic algorithm + // - use "v1" for v1 version of the cryptographic algorithm + Version string + + // Contains the password used deriving encryption key + Password string + + // Contains the salt for increasing password entropy + PasswordSalt string +} + +// DBCipher is cipher designed to perform validated encryption and decryption on database values. +type DBCipher struct { + encryptionCipher dbCipher + decryptionCiphers map[string]dbCipher +} + +// NewDBCipher creates a new instance of DBCipher based on the provided Config. +func NewDBCipher(conf Config) (*DBCipher, error) { + if conf.Password == "" { + return nil, errors.New("db password is empty") + } + if conf.PasswordSalt == "" { + return nil, errors.New("db password salt is empty") + } + if len(conf.PasswordSalt) < 32 { + return nil, errors.New("db password salt is too short") + } + c := &DBCipher{} + if err := c.registerCiphers(conf); err != nil { + return nil, err + } + return c, nil +} + +func (c *DBCipher) registerCiphers(conf Config) error { + v2, err := newDbCipherV2(conf) + if err != nil { + return err + } + v1, err := newDbCipherV1(conf) + if err != nil { + return err + } + + c.decryptionCiphers = map[string]dbCipher{ + v2.Prefix(): v2, + v1.Prefix(): v1, + } + switch conf.Version { + case "", "v2": + c.encryptionCipher = v2 + case "v1": + c.encryptionCipher = v1 + default: + return fmt.Errorf("invalid db cipher version %q", conf.Version) + } + return nil +} + +// Encrypt encrypts the provided bytes with DBCipher. +func (c *DBCipher) Encrypt(plaintext []byte) ([]byte, error) { + ciphertext, err := c.encryptionCipher.Encrypt(plaintext) + if err != nil { + return nil, err + } + return append([]byte(c.encryptionCipher.Prefix()+prefixSeparator), ciphertext...), nil +} + +// Decrypt decrypts the provided bytes with DBCipher. +func (c *DBCipher) Decrypt(ciphertextWithPrefix []byte) ([]byte, error) { + prefix, ciphertext, hasSeparator := bytes.Cut(ciphertextWithPrefix, []byte(prefixSeparator)) + if !hasSeparator { + return nil, errors.New("invalid encrypted value format") + } + cipher := c.decryptionCiphers[string(prefix)] + if cipher == nil { + return nil, errors.New("unknown encrypted value format") + } + plaintext, err := cipher.Decrypt(ciphertext) + if err != nil { + return nil, err + } + return plaintext, nil +} diff --git a/pkg/dbcrypt/dbcrypt.go b/pkg/dbcrypt/dbcrypt.go deleted file mode 100644 index 49b6e59..0000000 --- a/pkg/dbcrypt/dbcrypt.go +++ /dev/null @@ -1,142 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Greenbone AG -// -// SPDX-License-Identifier: AGPL-3.0-or-later - -package dbcrypt - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/hex" - "fmt" - "io" - "reflect" - - "github.com/greenbone/opensight-golang-libraries/pkg/dbcrypt/config" - - "github.com/rs/zerolog/log" -) - -const ( - prefix = "ENC:" - prefixLen = len(prefix) -) - -type DBCrypt[T any] struct { - config config.CryptoConfig -} - -func (d *DBCrypt[T]) loadKey() []byte { - if d.config == (config.CryptoConfig{}) { - conf, err := config.Read() - if err != nil { - log.Fatal().Err(err).Msg("crypto config is invalid") - } - d.config = conf - } - key := []byte(d.config.ReportEncryptionV1Password + d.config.ReportEncryptionV1Salt)[:32] // Truncate key to 32 bytes - return key -} - -// EncryptStruct encrypts all fields of a struct that are tagged with `encrypt:"true"` -func (d *DBCrypt[T]) EncryptStruct(data *T) error { - key := d.loadKey() - value := reflect.ValueOf(data).Elem() - valueType := value.Type() - for i := 0; i < value.NumField(); i++ { - field := value.Field(i) - fieldType := valueType.Field(i) - if encrypt, ok := fieldType.Tag.Lookup("encrypt"); ok && encrypt == "true" { - plaintext := fmt.Sprintf("%v", field.Interface()) - if len(plaintext) > prefixLen && plaintext[:prefixLen] == prefix { - // already encrypted goto next field - continue - } - ciphertext, err := Encrypt(plaintext, key) - if err != nil { - return err - } - field.SetString(ciphertext) - } - } - return nil -} - -// DecryptStruct decrypts all fields of a struct that are tagged with `encrypt:"true"` -func (d *DBCrypt[T]) DecryptStruct(data *T) error { - key := d.loadKey() - value := reflect.ValueOf(data).Elem() - valueType := value.Type() - for i := 0; i < value.NumField(); i++ { - field := value.Field(i) - fieldType := valueType.Field(i) - if encrypt, ok := fieldType.Tag.Lookup("encrypt"); ok && encrypt == "true" { - plaintext, err := Decrypt(field.String(), key) - if err != nil { - return err - } - field.SetString(plaintext) - } - } - return nil -} - -func Encrypt(plaintext string, key []byte) (string, error) { - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - iv := make([]byte, gcm.NonceSize()) - if _, err := io.ReadFull(rand.Reader, iv); err != nil { - return "", err - } - - ciphertext := gcm.Seal(nil, iv, []byte(plaintext), nil) - - encoded := hex.EncodeToString(append(iv, ciphertext...)) - return prefix + encoded, nil -} - -func Decrypt(encrypted string, key []byte) (string, error) { - if len(encrypted) <= prefixLen || encrypted[:prefixLen] != prefix { - return "", fmt.Errorf("invalid encrypted value format") - } - - encodedCiphertext := encrypted[4:] - - ciphertext, err := hex.DecodeString(encodedCiphertext) - if err != nil { - return "", fmt.Errorf("error decoding ciphertext: %w", err) - } - - if len(ciphertext) < aes.BlockSize+1 { - return "", fmt.Errorf("ciphertext too short") - } - - block, err := aes.NewCipher(key) - if err != nil { - return "", fmt.Errorf("error creating AES cipher: %w", err) - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - iv := ciphertext[:gcm.NonceSize()] - ciphertext = ciphertext[gcm.NonceSize():] - - plaintext, err := gcm.Open(nil, iv, ciphertext, nil) - if err != nil { - return "", fmt.Errorf("error decrypting ciphertext: %w", err) - } - - return string(plaintext), nil -} diff --git a/pkg/dbcrypt/dbcrypt_test.go b/pkg/dbcrypt/dbcrypt_test.go deleted file mode 100644 index 67632e2..0000000 --- a/pkg/dbcrypt/dbcrypt_test.go +++ /dev/null @@ -1,113 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Greenbone AG -// -// SPDX-License-Identifier: AGPL-3.0-or-later - -package dbcrypt - -import ( - "fmt" - "os" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gorm.io/driver/sqlite" - "gorm.io/gorm" -) - -type MyTable struct { - gorm.Model - Field1 string - PwdField string `encrypt:"true"` -} - -var cryptor = DBCrypt[MyTable]{} - -func (a *MyTable) encrypt(tx *gorm.DB) (err error) { - err = cryptor.EncryptStruct(a) - if err != nil { - err := tx.AddError(fmt.Errorf("unable to encrypt password %w", err)) - if err != nil { - return err - } - return err - } - return nil -} - -func (a *MyTable) BeforeCreate(tx *gorm.DB) (err error) { - return a.encrypt(tx) -} - -func (a *MyTable) BeforeUpdate(tx *gorm.DB) (err error) { - return a.encrypt(tx) -} - -func (a *MyTable) BeforeSave(tx *gorm.DB) (err error) { - return a.encrypt(tx) -} - -func (a *MyTable) AfterFind(tx *gorm.DB) (err error) { - err = cryptor.DecryptStruct(a) - if err != nil { - err := tx.AddError(fmt.Errorf("unable to decrypt password %w", err)) - if err != nil { - return err - } - return err - } - return nil -} - -func getTestDb(t *testing.T) *gorm.DB { - // db, err := gorm.Open(sqlite.Open("file:memory:?"), &gorm.Config{}) - db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) - require.NoError(t, err) - err = db.AutoMigrate(&MyTable{}) - require.NoError(t, err) - return db -} - -func TestEncryptDecrypt(t *testing.T) { - os.Setenv("TASK_REPORT_CRYPTO_V1_PASSWORD", "my-key-1234567890") - os.Setenv("TASK_REPORT_CRYPTO_V1_SALT", "my-salt-0987654321-0987654321-09") - defer func() { - os.Unsetenv("TASK_REPORT_CRYPTO_V1_PASSWORD") - os.Unsetenv("TASK_REPORT_CRYPTO_V1_SALT") - }() - - clearData := &MyTable{ - Field1: "111111111", - PwdField: "ThePassword", - } - originalPw := clearData.PwdField - - cryptor := DBCrypt[MyTable]{} - err := cryptor.EncryptStruct(clearData) - require.NoError(t, err) - require.NotEqual(t, originalPw, clearData.PwdField, "password was not encrypted") - err = cryptor.DecryptStruct(clearData) - require.NoError(t, err) - assert.Equal(t, originalPw, clearData.PwdField) -} - -func TestApplianceEncryption(t *testing.T) { - os.Setenv("TASK_REPORT_CRYPTO_V1_PASSWORD", "my-key-1234567890") - os.Setenv("TASK_REPORT_CRYPTO_V1_SALT", "my-salt-0987654321-0987654321-09") - defer func() { - os.Unsetenv("TASK_REPORT_CRYPTO_V1_PASSWORD") - os.Unsetenv("TASK_REPORT_CRYPTO_V1_SALT") - }() - - myDB := getTestDb(t) - tblData := &MyTable{ - Field1: "ajdf", - PwdField: "thePasswordWhichCanBeEncrypted", - } - myDB.Create(tblData) - assert.NotNil(t, tblData.ID) - - resultData := &MyTable{} - myDB.First(&resultData, tblData.ID) - assert.EqualValues(t, "thePasswordWhichCanBeEncrypted", resultData.PwdField) -} diff --git a/pkg/dbcrypt/dbcryptv2.go b/pkg/dbcrypt/dbcryptv2.go deleted file mode 100644 index b2121a5..0000000 --- a/pkg/dbcrypt/dbcryptv2.go +++ /dev/null @@ -1,90 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Greenbone AG -// -// SPDX-License-Identifier: AGPL-3.0-or-later - -package dbcrypt - -import ( - "bytes" - "errors" - "fmt" -) - -const prefixSeparator = ":" - -type DBCryptV2 struct { - encryptionCipher dbCipher - decryptionCiphers map[string]dbCipher -} - -// New creates a new instance of DBCryptV2 based on the provided Config. -func New(conf Config) (*DBCryptV2, error) { - if conf.Password == "" { - return nil, errors.New("db password is empty") - } - if conf.PasswordSalt == "" { - return nil, errors.New("db password salt is empty") - } - if len(conf.PasswordSalt) < 32 { - return nil, errors.New("db password salt is too short") - } - c := &DBCryptV2{} - if err := c.registerCiphers(conf); err != nil { - return nil, err - } - return c, nil -} - -func (c *DBCryptV2) registerCiphers(conf Config) error { - c.decryptionCiphers = map[string]dbCipher{} - for _, fn := range dbCiphers { - cipher, err := fn(conf) - if err != nil { - return err - } - c.decryptionCiphers[cipher.Prefix()] = cipher - } - - encryptionCipherPrefix := conf.Version - if encryptionCipherPrefix == "" { - encryptionCipherPrefix = defaultCipherPrefix - } - cipher := c.decryptionCiphers[encryptionCipherPrefix] - if cipher == nil { - return fmt.Errorf("invalid db cipher version %q", conf.Version) - } - c.encryptionCipher = cipher - return nil -} - -func (c *DBCryptV2) findDecryptionCipher(ciphertextWithPrefix []byte) (dbCipher, []byte, []byte) { - i := bytes.Index(ciphertextWithPrefix, []byte(prefixSeparator)) - if i < 0 { - return nil, nil, ciphertextWithPrefix - } - prefix, ciphertext := ciphertextWithPrefix[:i], ciphertextWithPrefix[i+len(prefixSeparator):] - return c.decryptionCiphers[string(prefix)], prefix, ciphertext -} - -func (c *DBCryptV2) Encrypt(plaintext []byte) ([]byte, error) { - ciphertext, err := c.encryptionCipher.Encrypt(plaintext) - if err != nil { - return nil, err - } - return append([]byte(c.encryptionCipher.Prefix()+prefixSeparator), ciphertext...), nil -} - -func (c *DBCryptV2) Decrypt(ciphertextWithPrefix []byte) ([]byte, error) { - cipher, prefix, ciphertext := c.findDecryptionCipher(ciphertextWithPrefix) - if len(prefix) == 0 { - return nil, errors.New("invalid encrypted value format") - } - if cipher == nil { - return nil, errors.New("unknown encrypted value format") - } - plaintext, err := cipher.Decrypt(ciphertext) - if err != nil { - return nil, err - } - return plaintext, nil -} diff --git a/pkg/dbcrypt/dbcryptv2_test.go b/pkg/dbcrypt/dbcryptv2_test.go deleted file mode 100644 index cc26471..0000000 --- a/pkg/dbcrypt/dbcryptv2_test.go +++ /dev/null @@ -1,224 +0,0 @@ -// SPDX-FileCopyrightText: 2024 Greenbone AG -// -// SPDX-License-Identifier: AGPL-3.0-or-later - -package dbcrypt_test - -import ( - "testing" - - "github.com/greenbone/opensight-golang-libraries/pkg/dbcrypt" - "github.com/stretchr/testify/require" - "gorm.io/driver/sqlite" - "gorm.io/gorm" -) - -func newTestDb[T any](t *testing.T) *gorm.DB { - db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) - require.NoError(t, err) - var table T - err = db.AutoMigrate(&table) - require.NoError(t, err) - return db -} - -func TestGormCreateRead(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString - } - db := newTestDb[Model](t) - crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, crypt)) - - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} - require.NoError(t, db.Create(&givenData).Error) - - gotData := Model{} - require.NoError(t, db.First(&gotData).Error) - require.Equal(t, givenData, gotData) -} - -func TestGormCreateReadNonPointerValue(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected dbcrypt.EncryptedString - } - db := newTestDb[Model](t) - crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, crypt)) - - givenData := Model{Protected: *dbcrypt.NewEncryptedString("aaa")} - require.NoError(t, db.Create(&givenData).Error) - - gotData := Model{} - require.NoError(t, db.First(&gotData).Error) - require.Equal(t, givenData, gotData) -} - -// func TestGormCreateRead(t *testing.T) { -// type Model struct { -// ID uint `gorm:"primarykey"` -// Protected string `encrypt:"true"` -// } -// db := newTestDb[Model](t) -// crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) -// require.NoError(t, err) -// require.NoError(t, dbcrypt.Register(db, crypt)) - -// givenData := Model{Protected: "aaa"} -// require.NoError(t, db.Create(&givenData).Error) - -// gotData := Model{} -// require.NoError(t, db.First(&gotData).Error) -// require.Equal(t, givenData, gotData) -// } - -// func TestGormCreateReadRaw(t *testing.T) { -// type Model struct { -// ID uint `gorm:"primarykey"` -// Protected string `encrypt:"true"` -// } -// db := newTestDb[Model](t) -// crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) -// require.NoError(t, err) -// require.NoError(t, dbcrypt.Register(db, crypt)) - -// givenData := Model{Protected: "aaa"} -// require.NoError(t, db.Create(&givenData).Error) - -// require.NoError(t, dbcrypt.Deregister(db)) - -// gotData := Model{} -// require.NoError(t, db.First(&gotData).Error) -// require.NotEqual(t, givenData, gotData) -// } - -func TestGormCreateReadRaw(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString - } - db := newTestDb[Model](t) - crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, crypt)) - - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} - require.NoError(t, db.Create(&givenData).Error) - - require.NoError(t, dbcrypt.Deregister(db)) - - gotData := Model{} - require.NoError(t, db.First(&gotData).Error) - require.NotEqual(t, givenData, gotData) - givenDataEncrypted, _ := givenData.Protected.Encrypted() - gotDataEncrypted, _ := gotData.Protected.Encrypted() - require.Equal(t, givenDataEncrypted, gotDataEncrypted) -} - -// func TestGormCreateUpdateRead(t *testing.T) { -// type Model struct { -// ID uint `gorm:"primarykey"` -// Protected string `encrypt:"true"` -// } -// db := newTestDb[Model](t) -// crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) -// require.NoError(t, err) -// require.NoError(t, dbcrypt.Register(db, crypt)) - -// givenData := Model{Protected: "aaa"} -// require.NoError(t, db.Create(&givenData).Error) - -// updatedData := Model{ID: givenData.ID, Protected: "bbb"} -// require.NoError(t, db.Model(&updatedData).Updates(&updatedData).Error) - -// gotData := Model{} -// require.NoError(t, db.First(&gotData).Error) -// require.Equal(t, updatedData, gotData) -// } - -func TestGormCreateUpdateRead(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString - } - db := newTestDb[Model](t) - crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, crypt)) - - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} - require.NoError(t, db.Create(&givenData).Error) - - updatedData := Model{ID: givenData.ID, Protected: dbcrypt.NewEncryptedString("bbb")} - require.NoError(t, db.Updates(&updatedData).Error) - - gotData := Model{} - require.NoError(t, db.First(&gotData).Error) - require.Equal(t, updatedData, gotData) -} - -func TestGormCreateColumnUpdateRead(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString - } - db := newTestDb[Model](t) - crypt, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, crypt)) - - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} - require.NoError(t, db.Create(&givenData).Error) - - updatedData := Model{ID: givenData.ID, Protected: dbcrypt.NewEncryptedString("bbb")} - require.NoError(t, db.Model(&Model{}).Where("id = ?", updatedData.ID).Update("protected", dbcrypt.NewEncryptedString("bbb")).Error) - - gotData := Model{} - require.NoError(t, db.First(&gotData).Error) - gotData.Protected.ClearEncrypted() - require.Equal(t, updatedData, gotData) -} - -// func TestGormMixDBCryptInstances(t *testing.T) { -// type Model struct { -// ID uint `gorm:"primarykey"` -// Protected string `encrypt:"true"` -// } -// db := newTestDb[Model](t) -// cryptFirst, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) -// require.NoError(t, err) -// cryptSecond, err := dbcrypt.New(dbcrypt.Config{Password: "other-encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) -// require.NoError(t, err) -// require.NoError(t, dbcrypt.Register(db, cryptFirst)) - -// givenData := Model{Protected: "aaa"} -// require.NoError(t, db.Create(&givenData).Error) - -// require.NoError(t, dbcrypt.Deregister(db)) -// require.NoError(t, dbcrypt.Register(db, cryptSecond)) -// require.Error(t, db.First(&Model{}).Error) -// } - -func TestGormMixDBCryptInstances(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString - } - db := newTestDb[Model](t) - cryptFirst, err := dbcrypt.New(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - cryptSecond, err := dbcrypt.New(dbcrypt.Config{Password: "other-encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, cryptFirst)) - - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} - require.NoError(t, db.Create(&givenData).Error) - - require.NoError(t, dbcrypt.Deregister(db)) - require.NoError(t, dbcrypt.Register(db, cryptSecond)) - require.Error(t, db.First(&Model{}).Error) -} diff --git a/pkg/dbcrypt/gorm_test.go b/pkg/dbcrypt/gorm_test.go new file mode 100644 index 0000000..a246f6e --- /dev/null +++ b/pkg/dbcrypt/gorm_test.go @@ -0,0 +1,198 @@ +// SPDX-FileCopyrightText: 2024 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + +package dbcrypt_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + + "github.com/greenbone/opensight-golang-libraries/pkg/dbcrypt" +) + +func newTestDb[T any](t *testing.T) *gorm.DB { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + require.NoError(t, err) + var table T + err = db.AutoMigrate(&table) + require.NoError(t, err) + return db +} + +func TestGormCreateReadWithTag(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected string `encrypt:"true"` + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: "aaa"} + require.NoError(t, db.Create(&givenData).Error) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + require.Equal(t, givenData, gotData) +} + +func TestGormCreateReadWithType(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + require.Equal(t, givenData, gotData) +} + +func TestGormCreateReadWithTypeNonPointerValue(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: *dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + require.Equal(t, givenData, gotData) +} + +func TestGormCreateReadRawWithTag(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected string `encrypt:"true"` + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: "aaa"} + require.NoError(t, db.Create(&givenData).Error) + + gotData := Model{} + require.NoError(t, db.Raw(`SELECT * FROM models LIMIT 1`).Scan(&gotData).Error) + require.NotEqual(t, givenData, gotData) +} + +func TestGormCreateReadRawWithType(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + gotData := Model{} + require.NoError(t, db.Raw(`SELECT * FROM models LIMIT 1`).Scan(&gotData).Error) + require.NotEqual(t, givenData, gotData) + givenDataEncrypted, _ := givenData.Protected.Encrypted() + gotDataEncrypted, _ := gotData.Protected.Encrypted() + require.Equal(t, givenDataEncrypted, gotDataEncrypted) +} + +func TestGormCreateUpdateReadWithType(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + updatedData := Model{ID: givenData.ID, Protected: dbcrypt.NewEncryptedString("bbb")} + require.NoError(t, db.Updates(&updatedData).Error) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + require.Equal(t, updatedData, gotData) +} + +func TestGormCreateColumnUpdateReadWithType(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, crypt)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + updatedData := Model{ID: givenData.ID, Protected: dbcrypt.NewEncryptedString("bbb")} + require.NoError(t, db.Model(&Model{}).Where("id = ?", updatedData.ID).Update("protected", dbcrypt.NewEncryptedString("bbb")).Error) + + gotData := Model{} + require.NoError(t, db.First(&gotData).Error) + gotData.Protected.ClearEncrypted() + require.Equal(t, updatedData, gotData) +} + +func TestGormMixDBCryptInstancesWithTag(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected string `encrypt:"true"` + } + db := newTestDb[Model](t) + cryptFirst, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + cryptSecond, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "other-encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, cryptFirst)) + + givenData := Model{Protected: "aaa"} + require.NoError(t, db.Create(&givenData).Error) + + require.NoError(t, dbcrypt.Register(db, cryptSecond)) + require.Error(t, db.First(&Model{}).Error) +} + +func TestGormMixDBCryptInstancesWithType(t *testing.T) { + type Model struct { + ID uint `gorm:"primarykey"` + Protected *dbcrypt.EncryptedString + } + db := newTestDb[Model](t) + cryptFirst, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + cryptSecond, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "other-encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + require.NoError(t, err) + require.NoError(t, dbcrypt.Register(db, cryptFirst)) + + givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + require.NoError(t, db.Create(&givenData).Error) + + require.NoError(t, dbcrypt.Register(db, cryptSecond)) + require.Error(t, db.First(&Model{}).Error) +} diff --git a/pkg/dbcrypt/value.go b/pkg/dbcrypt/value.go index 7d46297..46f01bc 100644 --- a/pkg/dbcrypt/value.go +++ b/pkg/dbcrypt/value.go @@ -5,15 +5,27 @@ import ( "errors" ) +// EncryptedString is a wrapper around string that indicates that the value should be encrypted wile stored. type EncryptedString struct { encrypted string decrypted string } -func NewEncryptedString(val string) *EncryptedString { - return &EncryptedString{decrypted: val} +// NewEncryptedString creates a new EncryptedString based on plaintext data. Returned value, until encrypted, will miss associated ciphertext value. +func NewEncryptedString(dec string) *EncryptedString { + return &EncryptedString{decrypted: dec} } +// DecryptEncryptedString creates a new EncryptedString based on ciphertext data. It automatically decrypts it using the provided DBCipher, so both plaintext and ciphertext values are available. +func DecryptEncryptedString(c *DBCipher, enc string) (*EncryptedString, error) { + es := &EncryptedString{encrypted: enc} + if err := es.decrypt(c); err != nil { + return nil, err + } + return es, nil +} + +// Scan unmarshal encrypted stored value into EncryptedString. func (es *EncryptedString) Scan(v any) error { enc, ok := v.(string) if !ok { @@ -23,6 +35,7 @@ func (es *EncryptedString) Scan(v any) error { return nil } +// Value returns encrypted value for storing. func (es EncryptedString) Value() (driver.Value, error) { enc, ok := es.Encrypted() if !ok { @@ -34,6 +47,7 @@ func (es EncryptedString) Value() (driver.Value, error) { return enc, nil } +// Encrypted returns ciphertext (encrypted) value of EncryptedString and true, if encrypted value is available. Otherwise it return an empty string and false. func (es *EncryptedString) Encrypted() (string, bool) { if es == nil { return "", true @@ -42,32 +56,35 @@ func (es *EncryptedString) Encrypted() (string, bool) { return es.encrypted, has } -func (es *EncryptedString) Encrypt(c *DBCryptV2) error { - enc, err := EncryptString(c, es.decrypted) +// Encrypt generates a new ciphertext value based on plaintext value using the provided DBCipher. +func (es *EncryptedString) Encrypt(c *DBCipher) error { + enc, err := c.Encrypt([]byte(es.decrypted)) if err != nil { return err } - es.encrypted = enc + es.encrypted = string(enc) return nil } +// ClearEncrypted removes associated encrypted value. func (es *EncryptedString) ClearEncrypted() { es.encrypted = "" } -func (es *EncryptedString) decrypt(c *DBCryptV2) error { +func (es *EncryptedString) decrypt(c *DBCipher) error { enc, ok := es.Encrypted() if !ok || enc == "" { return nil } - dec, err := DecryptString(c, enc) + dec, err := c.Decrypt([]byte(enc)) if err != nil { return err } - es.decrypted = dec + es.decrypted = string(dec) return nil } +// Get returns plaintext (decrypted) value of EncryptedString. func (es *EncryptedString) Get() string { if es == nil { return "" @@ -75,6 +92,7 @@ func (es *EncryptedString) Get() string { return es.decrypted } +// Set sets plaintext (decrypted) value of EncryptedString. func (es *EncryptedString) Set(to string) { es.encrypted, es.decrypted = "", to } From 845f26ae63dd542a676395432903bfb6b3e436a8 Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Fri, 21 Nov 2025 15:07:32 +0100 Subject: [PATCH 03/13] change: refine new implementation of dbcrypt package --- pkg/dbcrypt/README.md | 11 +--- pkg/dbcrypt/cipher_test.go | 2 +- pkg/dbcrypt/crypto.go | 16 +---- pkg/dbcrypt/dbcipher.go | 47 ++++++++----- pkg/dbcrypt/gorm_test.go | 132 +++++-------------------------------- pkg/dbcrypt/value.go | 98 --------------------------- 6 files changed, 51 insertions(+), 255 deletions(-) delete mode 100644 pkg/dbcrypt/value.go diff --git a/pkg/dbcrypt/README.md b/pkg/dbcrypt/README.md index 765f60f..11d51dd 100644 --- a/pkg/dbcrypt/README.md +++ b/pkg/dbcrypt/README.md @@ -19,7 +19,7 @@ import ( type Person struct { gorm.Model - PasswordField *dbcrypt.EncryptedString `encrypt:"true"` + PasswordField string `encrypt:"true"` } func main() { @@ -51,15 +51,6 @@ func main() { In this example, a Person struct is created and `PasswordField` is automatically encrypted before storing in the database using the DBCipher. Then, when the data is retrieved from the database `PasswordField` is automatically decrypted. -Alternatively while creating the model you can use a tags instead of dedicated types: - -```go -type Person struct { - gorm.Model - PasswordField string `encrypt:"true"` -} -``` - # License Copyright (C) 2022-2023 [Greenbone AG][Greenbone AG] diff --git a/pkg/dbcrypt/cipher_test.go b/pkg/dbcrypt/cipher_test.go index 16ef9bb..ea4afc4 100644 --- a/pkg/dbcrypt/cipher_test.go +++ b/pkg/dbcrypt/cipher_test.go @@ -113,7 +113,7 @@ func TestCipherCreationFailure(t *testing.T) { Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456", }, - errorShouldContain: "cipher version", + errorShouldContain: "invalid db cipher version", }, { name: "empty-password", diff --git a/pkg/dbcrypt/crypto.go b/pkg/dbcrypt/crypto.go index 79ab8f5..1a65521 100644 --- a/pkg/dbcrypt/crypto.go +++ b/pkg/dbcrypt/crypto.go @@ -41,13 +41,6 @@ func encryptModel(c *DBCipher, plaintext any) error { } func encryptRecursive(c *DBCipher, plaintext reflect.Value) error { - if es, ok := plaintext.Interface().(EncryptedString); ok { - if err := es.Encrypt(c); err != nil { - return err - } - plaintext.Set(reflect.ValueOf(es)) - return nil - } if plaintext.Kind() == reflect.Pointer || plaintext.Kind() == reflect.Interface { if plaintext.IsNil() { return nil @@ -107,13 +100,6 @@ func decryptModel(c *DBCipher, ciphertext any) error { } func decryptRecursive(c *DBCipher, ciphertext reflect.Value) error { - if es, ok := ciphertext.Interface().(EncryptedString); ok { - if err := es.decrypt(c); err != nil { - return err - } - ciphertext.Set(reflect.ValueOf(es)) - return nil - } if ciphertext.Kind() == reflect.Pointer || ciphertext.Kind() == reflect.Interface { if ciphertext.IsNil() { return nil @@ -161,7 +147,7 @@ func decryptFieldBasedOnTag(c *DBCipher, sf reflect.StructField, val reflect.Val return nil } -// Register registers encryption and decryption callbacks for the provided data base, to perform automatically cryptographic operations on all models that contain a value of type EncryptedString or a file tagged with 'encrypt:"true"'. +// Register registers encryption and decryption callbacks for the provided data base, to perform automatically cryptographic operations on all models that contain a field tagged with 'encrypt:"true"'. func Register(db *gorm.DB, c *DBCipher) error { encryptCb := func(db *gorm.DB) { db.AddError(encryptModel(c, db.Statement.Dest)) //nolint:errcheck // error value returned by AddError can be safely ignored, as it is the same error as db.Error. diff --git a/pkg/dbcrypt/dbcipher.go b/pkg/dbcrypt/dbcipher.go index 9f43302..a98c6e9 100644 --- a/pkg/dbcrypt/dbcipher.go +++ b/pkg/dbcrypt/dbcipher.go @@ -16,7 +16,7 @@ const prefixSeparator = ":" type Config struct { // Default version of the cryptographic algorithm. Useful for testing older historical implementations. Leave empty to use the most recent version. // - // - use for v2 version of the cryptographic algorithm + // - use "" for latest version of the cryptographic algorithm // - use "v2" for v2 version of the cryptographic algorithm // - use "v1" for v1 version of the cryptographic algorithm Version string @@ -28,6 +28,23 @@ type Config struct { PasswordSalt string } +// Validate validates the provided config. +func (conf Config) Validate() error { + if conf.Version != "" && conf.Version != "v1" && conf.Version != "v2" { + return fmt.Errorf("invalid db cipher version %q", conf.Version) + } + if conf.Password == "" { + return errors.New("db password is empty") + } + if conf.PasswordSalt == "" { + return errors.New("db password salt is empty") + } + if len(conf.PasswordSalt) < 32 { + return errors.New("db password salt is too short") + } + return nil +} + // DBCipher is cipher designed to perform validated encryption and decryption on database values. type DBCipher struct { encryptionCipher dbCipher @@ -36,14 +53,8 @@ type DBCipher struct { // NewDBCipher creates a new instance of DBCipher based on the provided Config. func NewDBCipher(conf Config) (*DBCipher, error) { - if conf.Password == "" { - return nil, errors.New("db password is empty") - } - if conf.PasswordSalt == "" { - return nil, errors.New("db password salt is empty") - } - if len(conf.PasswordSalt) < 32 { - return nil, errors.New("db password salt is too short") + if err := conf.Validate(); err != nil { + return nil, err } c := &DBCipher{} if err := c.registerCiphers(conf); err != nil { @@ -53,26 +64,26 @@ func NewDBCipher(conf Config) (*DBCipher, error) { } func (c *DBCipher) registerCiphers(conf Config) error { - v2, err := newDbCipherV2(conf) + v1, err := newDbCipherV1(conf) if err != nil { return err } - v1, err := newDbCipherV1(conf) + v2, err := newDbCipherV2(conf) if err != nil { return err } c.decryptionCiphers = map[string]dbCipher{ - v2.Prefix(): v2, v1.Prefix(): v1, + v2.Prefix(): v2, } switch conf.Version { - case "", "v2": - c.encryptionCipher = v2 case "v1": c.encryptionCipher = v1 + case "v2", "": + c.encryptionCipher = v2 default: - return fmt.Errorf("invalid db cipher version %q", conf.Version) + panic("invalid db cipher version") // valid config should never reach this code } return nil } @@ -83,7 +94,11 @@ func (c *DBCipher) Encrypt(plaintext []byte) ([]byte, error) { if err != nil { return nil, err } - return append([]byte(c.encryptionCipher.Prefix()+prefixSeparator), ciphertext...), nil + ciphertextWithPrefix := bytes.NewBuffer(nil) + ciphertextWithPrefix.WriteString(c.encryptionCipher.Prefix()) + ciphertextWithPrefix.WriteString(prefixSeparator) + ciphertextWithPrefix.Write(ciphertext) + return ciphertextWithPrefix.Bytes(), nil } // Decrypt decrypts the provided bytes with DBCipher. diff --git a/pkg/dbcrypt/gorm_test.go b/pkg/dbcrypt/gorm_test.go index a246f6e..199b253 100644 --- a/pkg/dbcrypt/gorm_test.go +++ b/pkg/dbcrypt/gorm_test.go @@ -23,17 +23,17 @@ func newTestDb[T any](t *testing.T) *gorm.DB { return db } -func TestGormCreateReadWithTag(t *testing.T) { +func TestGormCreateRead(t *testing.T) { type Model struct { - ID uint `gorm:"primarykey"` - Protected string `encrypt:"true"` + ID uint `gorm:"primarykey"` + Secret string `encrypt:"true"` } db := newTestDb[Model](t) crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) require.NoError(t, err) require.NoError(t, dbcrypt.Register(db, crypt)) - givenData := Model{Protected: "aaa"} + givenData := Model{Secret: "aaa"} require.NoError(t, db.Create(&givenData).Error) gotData := Model{} @@ -41,95 +41,38 @@ func TestGormCreateReadWithTag(t *testing.T) { require.Equal(t, givenData, gotData) } -func TestGormCreateReadWithType(t *testing.T) { +func TestGormCreateReadRaw(t *testing.T) { type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString + ID uint `gorm:"primarykey"` + Secret string `encrypt:"true"` } db := newTestDb[Model](t) crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) require.NoError(t, err) require.NoError(t, dbcrypt.Register(db, crypt)) - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} - require.NoError(t, db.Create(&givenData).Error) - - gotData := Model{} - require.NoError(t, db.First(&gotData).Error) - require.Equal(t, givenData, gotData) -} - -func TestGormCreateReadWithTypeNonPointerValue(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected dbcrypt.EncryptedString - } - db := newTestDb[Model](t) - crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, crypt)) - - givenData := Model{Protected: *dbcrypt.NewEncryptedString("aaa")} - require.NoError(t, db.Create(&givenData).Error) - - gotData := Model{} - require.NoError(t, db.First(&gotData).Error) - require.Equal(t, givenData, gotData) -} - -func TestGormCreateReadRawWithTag(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected string `encrypt:"true"` - } - db := newTestDb[Model](t) - crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, crypt)) - - givenData := Model{Protected: "aaa"} - require.NoError(t, db.Create(&givenData).Error) - - gotData := Model{} - require.NoError(t, db.Raw(`SELECT * FROM models LIMIT 1`).Scan(&gotData).Error) - require.NotEqual(t, givenData, gotData) -} - -func TestGormCreateReadRawWithType(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString - } - db := newTestDb[Model](t) - crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, crypt)) - - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + givenData := Model{Secret: "aaa"} require.NoError(t, db.Create(&givenData).Error) gotData := Model{} require.NoError(t, db.Raw(`SELECT * FROM models LIMIT 1`).Scan(&gotData).Error) require.NotEqual(t, givenData, gotData) - givenDataEncrypted, _ := givenData.Protected.Encrypted() - gotDataEncrypted, _ := gotData.Protected.Encrypted() - require.Equal(t, givenDataEncrypted, gotDataEncrypted) } -func TestGormCreateUpdateReadWithType(t *testing.T) { +func TestGormCreateUpdateRead(t *testing.T) { type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString + ID uint `gorm:"primarykey"` + Secret string `encrypt:"true"` } db := newTestDb[Model](t) crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) require.NoError(t, err) require.NoError(t, dbcrypt.Register(db, crypt)) - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + givenData := Model{Secret: "aaa"} require.NoError(t, db.Create(&givenData).Error) - updatedData := Model{ID: givenData.ID, Protected: dbcrypt.NewEncryptedString("bbb")} + updatedData := Model{ID: givenData.ID, Secret: "bbb"} require.NoError(t, db.Updates(&updatedData).Error) gotData := Model{} @@ -137,51 +80,10 @@ func TestGormCreateUpdateReadWithType(t *testing.T) { require.Equal(t, updatedData, gotData) } -func TestGormCreateColumnUpdateReadWithType(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString - } - db := newTestDb[Model](t) - crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, crypt)) - - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} - require.NoError(t, db.Create(&givenData).Error) - - updatedData := Model{ID: givenData.ID, Protected: dbcrypt.NewEncryptedString("bbb")} - require.NoError(t, db.Model(&Model{}).Where("id = ?", updatedData.ID).Update("protected", dbcrypt.NewEncryptedString("bbb")).Error) - - gotData := Model{} - require.NoError(t, db.First(&gotData).Error) - gotData.Protected.ClearEncrypted() - require.Equal(t, updatedData, gotData) -} - -func TestGormMixDBCryptInstancesWithTag(t *testing.T) { - type Model struct { - ID uint `gorm:"primarykey"` - Protected string `encrypt:"true"` - } - db := newTestDb[Model](t) - cryptFirst, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - cryptSecond, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "other-encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) - require.NoError(t, err) - require.NoError(t, dbcrypt.Register(db, cryptFirst)) - - givenData := Model{Protected: "aaa"} - require.NoError(t, db.Create(&givenData).Error) - - require.NoError(t, dbcrypt.Register(db, cryptSecond)) - require.Error(t, db.First(&Model{}).Error) -} - -func TestGormMixDBCryptInstancesWithType(t *testing.T) { +func TestGormMixDBCryptInstances(t *testing.T) { type Model struct { - ID uint `gorm:"primarykey"` - Protected *dbcrypt.EncryptedString + ID uint `gorm:"primarykey"` + Secret string `encrypt:"true"` } db := newTestDb[Model](t) cryptFirst, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) @@ -190,7 +92,7 @@ func TestGormMixDBCryptInstancesWithType(t *testing.T) { require.NoError(t, err) require.NoError(t, dbcrypt.Register(db, cryptFirst)) - givenData := Model{Protected: dbcrypt.NewEncryptedString("aaa")} + givenData := Model{Secret: "aaa"} require.NoError(t, db.Create(&givenData).Error) require.NoError(t, dbcrypt.Register(db, cryptSecond)) diff --git a/pkg/dbcrypt/value.go b/pkg/dbcrypt/value.go deleted file mode 100644 index 46f01bc..0000000 --- a/pkg/dbcrypt/value.go +++ /dev/null @@ -1,98 +0,0 @@ -package dbcrypt - -import ( - "database/sql/driver" - "errors" -) - -// EncryptedString is a wrapper around string that indicates that the value should be encrypted wile stored. -type EncryptedString struct { - encrypted string - decrypted string -} - -// NewEncryptedString creates a new EncryptedString based on plaintext data. Returned value, until encrypted, will miss associated ciphertext value. -func NewEncryptedString(dec string) *EncryptedString { - return &EncryptedString{decrypted: dec} -} - -// DecryptEncryptedString creates a new EncryptedString based on ciphertext data. It automatically decrypts it using the provided DBCipher, so both plaintext and ciphertext values are available. -func DecryptEncryptedString(c *DBCipher, enc string) (*EncryptedString, error) { - es := &EncryptedString{encrypted: enc} - if err := es.decrypt(c); err != nil { - return nil, err - } - return es, nil -} - -// Scan unmarshal encrypted stored value into EncryptedString. -func (es *EncryptedString) Scan(v any) error { - enc, ok := v.(string) - if !ok { - return errors.New("failed to unmarshal encrypted string value") - } - es.encrypted, es.decrypted = enc, "" - return nil -} - -// Value returns encrypted value for storing. -func (es EncryptedString) Value() (driver.Value, error) { - enc, ok := es.Encrypted() - if !ok { - return nil, errors.New("cannot store string value: encryption required") - } - if enc == "" { - return nil, nil - } - return enc, nil -} - -// Encrypted returns ciphertext (encrypted) value of EncryptedString and true, if encrypted value is available. Otherwise it return an empty string and false. -func (es *EncryptedString) Encrypted() (string, bool) { - if es == nil { - return "", true - } - has := es.encrypted != "" || es.decrypted == "" - return es.encrypted, has -} - -// Encrypt generates a new ciphertext value based on plaintext value using the provided DBCipher. -func (es *EncryptedString) Encrypt(c *DBCipher) error { - enc, err := c.Encrypt([]byte(es.decrypted)) - if err != nil { - return err - } - es.encrypted = string(enc) - return nil -} - -// ClearEncrypted removes associated encrypted value. -func (es *EncryptedString) ClearEncrypted() { - es.encrypted = "" -} - -func (es *EncryptedString) decrypt(c *DBCipher) error { - enc, ok := es.Encrypted() - if !ok || enc == "" { - return nil - } - dec, err := c.Decrypt([]byte(enc)) - if err != nil { - return err - } - es.decrypted = string(dec) - return nil -} - -// Get returns plaintext (decrypted) value of EncryptedString. -func (es *EncryptedString) Get() string { - if es == nil { - return "" - } - return es.decrypted -} - -// Set sets plaintext (decrypted) value of EncryptedString. -func (es *EncryptedString) Set(to string) { - es.encrypted, es.decrypted = "", to -} From 85bed572c3d854929315c2c2ba0cfaf8f2ec412b Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Fri, 21 Nov 2025 15:08:20 +0100 Subject: [PATCH 04/13] change: refine new implementation of dbcrypt package --- pkg/dbcrypt/gorm_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/dbcrypt/gorm_test.go b/pkg/dbcrypt/gorm_test.go index 199b253..7ddff27 100644 --- a/pkg/dbcrypt/gorm_test.go +++ b/pkg/dbcrypt/gorm_test.go @@ -56,7 +56,8 @@ func TestGormCreateReadRaw(t *testing.T) { gotData := Model{} require.NoError(t, db.Raw(`SELECT * FROM models LIMIT 1`).Scan(&gotData).Error) - require.NotEqual(t, givenData, gotData) + require.Equal(t, givenData.ID, gotData.ID) + require.NotEqual(t, givenData.Secret, gotData.Secret) } func TestGormCreateUpdateRead(t *testing.T) { From 9875c17d521ed15c80c0af7b9ee6fcef165141d9 Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Mon, 24 Nov 2025 12:26:08 +0100 Subject: [PATCH 05/13] change: add slices supprt in models for dbcrypt package --- pkg/dbcrypt/crypto.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/pkg/dbcrypt/crypto.go b/pkg/dbcrypt/crypto.go index 1a65521..88abce5 100644 --- a/pkg/dbcrypt/crypto.go +++ b/pkg/dbcrypt/crypto.go @@ -34,6 +34,9 @@ func encryptModel(c *DBCipher, plaintext any) error { if value.Kind() == reflect.Pointer && value.Type().Elem().Kind() == reflect.Struct { return encryptRecursive(c, value) } + if value.Kind() == reflect.Slice { + return encryptRecursive(c, value) + } if value.Kind() == reflect.Map { return encryptRecursive(c, value) } @@ -62,6 +65,13 @@ func encryptRecursive(c *DBCipher, plaintext reflect.Value) error { } } } + if plaintext.Kind() == reflect.Slice { + for i, v := range plaintext.Seq2() { + if err := encryptRecursive(c, v); err != nil { + return fmt.Errorf("list item #%d: %w", i.Int(), err) + } + } + } if plaintext.Kind() == reflect.Map { for k, v := range plaintext.Seq2() { if err := encryptRecursive(c, v); err != nil { @@ -93,6 +103,9 @@ func decryptModel(c *DBCipher, ciphertext any) error { if value.Kind() == reflect.Pointer && value.Type().Elem().Kind() == reflect.Struct { return decryptRecursive(c, value) } + if value.Kind() == reflect.Slice { + return decryptRecursive(c, value) + } if value.Kind() == reflect.Map { return decryptRecursive(c, value) } @@ -121,6 +134,13 @@ func decryptRecursive(c *DBCipher, ciphertext reflect.Value) error { } } } + if ciphertext.Kind() == reflect.Slice { + for i, v := range ciphertext.Seq2() { + if err := decryptRecursive(c, v); err != nil { + return fmt.Errorf("list item #%d: %w", i.Int(), err) + } + } + } if ciphertext.Kind() == reflect.Map { for k, v := range ciphertext.Seq2() { if err := decryptRecursive(c, v); err != nil { From b3ced4c8e6add75b04dcc24b98898e0f969d04b6 Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Mon, 24 Nov 2025 13:44:16 +0100 Subject: [PATCH 06/13] change: refactor model travelsal in dbcrypt package --- pkg/dbcrypt/crypto.go | 71 +++++++++++++++++++++------------------- pkg/dbcrypt/gorm_test.go | 25 +++++++++++--- 2 files changed, 57 insertions(+), 39 deletions(-) diff --git a/pkg/dbcrypt/crypto.go b/pkg/dbcrypt/crypto.go index 88abce5..3ef45dc 100644 --- a/pkg/dbcrypt/crypto.go +++ b/pkg/dbcrypt/crypto.go @@ -29,28 +29,39 @@ func parseEncryptStructFieldTag(sf reflect.StructField) (bool, error) { return true, nil } -func encryptModel(c *DBCipher, plaintext any) error { - value := reflect.ValueOf(plaintext) - if value.Kind() == reflect.Pointer && value.Type().Elem().Kind() == reflect.Struct { - return encryptRecursive(c, value) - } - if value.Kind() == reflect.Slice { - return encryptRecursive(c, value) +func modelValue(model any) reflect.Value { + value := reflect.ValueOf(model) + for { + switch value.Kind() { + case reflect.Pointer, reflect.Interface: + if value.IsNil() { + return value + } + value = value.Elem() + default: + return value + } } - if value.Kind() == reflect.Map { +} + +func encryptModel(c *DBCipher, plaintext any) error { + value := modelValue(plaintext) + switch value.Kind() { + case reflect.Struct, reflect.Slice, reflect.Map: return encryptRecursive(c, value) + default: + return errors.New("invalid value provided for encryption") } - return errors.New("invalid value provided for encryption") } func encryptRecursive(c *DBCipher, plaintext reflect.Value) error { - if plaintext.Kind() == reflect.Pointer || plaintext.Kind() == reflect.Interface { + switch plaintext.Kind() { + case reflect.Pointer, reflect.Interface: if plaintext.IsNil() { return nil } return encryptRecursive(c, plaintext.Elem()) - } - if plaintext.Kind() == reflect.Struct { + case reflect.Struct: typ := plaintext.Type() for i := 0; i < typ.NumField(); i++ { fTyp := typ.Field(i) @@ -64,15 +75,13 @@ func encryptRecursive(c *DBCipher, plaintext reflect.Value) error { return fmt.Errorf("field %q: %w", fTyp.Name, err) } } - } - if plaintext.Kind() == reflect.Slice { + case reflect.Slice: for i, v := range plaintext.Seq2() { if err := encryptRecursive(c, v); err != nil { return fmt.Errorf("list item #%d: %w", i.Int(), err) } } - } - if plaintext.Kind() == reflect.Map { + case reflect.Map: for k, v := range plaintext.Seq2() { if err := encryptRecursive(c, v); err != nil { return fmt.Errorf("map key %q: %w", k.String(), err) @@ -99,27 +108,23 @@ func encryptFieldBasedOnTag(c *DBCipher, sf reflect.StructField, val reflect.Val } func decryptModel(c *DBCipher, ciphertext any) error { - value := reflect.ValueOf(ciphertext) - if value.Kind() == reflect.Pointer && value.Type().Elem().Kind() == reflect.Struct { + value := modelValue(ciphertext) + switch value.Kind() { + case reflect.Struct, reflect.Slice, reflect.Map: return decryptRecursive(c, value) + default: + return errors.New("invalid value provided for decryption") } - if value.Kind() == reflect.Slice { - return decryptRecursive(c, value) - } - if value.Kind() == reflect.Map { - return decryptRecursive(c, value) - } - return errors.New("invalid value provided for decryption") } func decryptRecursive(c *DBCipher, ciphertext reflect.Value) error { - if ciphertext.Kind() == reflect.Pointer || ciphertext.Kind() == reflect.Interface { + switch ciphertext.Kind() { + case reflect.Pointer, reflect.Interface: if ciphertext.IsNil() { return nil } return decryptRecursive(c, ciphertext.Elem()) - } - if ciphertext.Kind() == reflect.Struct { + case reflect.Struct: typ := ciphertext.Type() for i := 0; i < typ.NumField(); i++ { fTyp := typ.Field(i) @@ -133,15 +138,13 @@ func decryptRecursive(c *DBCipher, ciphertext reflect.Value) error { return fmt.Errorf("field %q: %w", fTyp.Name, err) } } - } - if ciphertext.Kind() == reflect.Slice { + case reflect.Slice: for i, v := range ciphertext.Seq2() { if err := decryptRecursive(c, v); err != nil { return fmt.Errorf("list item #%d: %w", i.Int(), err) } } - } - if ciphertext.Kind() == reflect.Map { + case reflect.Map: for k, v := range ciphertext.Seq2() { if err := decryptRecursive(c, v); err != nil { return fmt.Errorf("map key %q: %w", k.String(), err) @@ -170,10 +173,10 @@ func decryptFieldBasedOnTag(c *DBCipher, sf reflect.StructField, val reflect.Val // Register registers encryption and decryption callbacks for the provided data base, to perform automatically cryptographic operations on all models that contain a field tagged with 'encrypt:"true"'. func Register(db *gorm.DB, c *DBCipher) error { encryptCb := func(db *gorm.DB) { - db.AddError(encryptModel(c, db.Statement.Dest)) //nolint:errcheck // error value returned by AddError can be safely ignored, as it is the same error as db.Error. + db.AddError(encryptModel(c, &db.Statement.Dest)) //nolint:errcheck // error value returned by AddError can be safely ignored, as it is the same error as db.Error. } decryptCb := func(db *gorm.DB) { - db.AddError(decryptModel(c, db.Statement.Dest)) //nolint:errcheck // error value returned by AddError can be safely ignored, as it is the same error as db.Error. + db.AddError(decryptModel(c, &db.Statement.Dest)) //nolint:errcheck // error value returned by AddError can be safely ignored, as it is the same error as db.Error. } if err := db.Callback(). diff --git a/pkg/dbcrypt/gorm_test.go b/pkg/dbcrypt/gorm_test.go index 7ddff27..b8e3100 100644 --- a/pkg/dbcrypt/gorm_test.go +++ b/pkg/dbcrypt/gorm_test.go @@ -29,7 +29,10 @@ func TestGormCreateRead(t *testing.T) { Secret string `encrypt:"true"` } db := newTestDb[Model](t) - crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{ + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }) require.NoError(t, err) require.NoError(t, dbcrypt.Register(db, crypt)) @@ -47,7 +50,10 @@ func TestGormCreateReadRaw(t *testing.T) { Secret string `encrypt:"true"` } db := newTestDb[Model](t) - crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{ + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }) require.NoError(t, err) require.NoError(t, dbcrypt.Register(db, crypt)) @@ -66,7 +72,10 @@ func TestGormCreateUpdateRead(t *testing.T) { Secret string `encrypt:"true"` } db := newTestDb[Model](t) - crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + crypt, err := dbcrypt.NewDBCipher(dbcrypt.Config{ + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }) require.NoError(t, err) require.NoError(t, dbcrypt.Register(db, crypt)) @@ -87,9 +96,15 @@ func TestGormMixDBCryptInstances(t *testing.T) { Secret string `encrypt:"true"` } db := newTestDb[Model](t) - cryptFirst, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + cryptFirst, err := dbcrypt.NewDBCipher(dbcrypt.Config{ + Password: "encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }) require.NoError(t, err) - cryptSecond, err := dbcrypt.NewDBCipher(dbcrypt.Config{Password: "other-encryption-password", PasswordSalt: "encryption-password-salt-0123456"}) + cryptSecond, err := dbcrypt.NewDBCipher(dbcrypt.Config{ + Password: "other-encryption-password", + PasswordSalt: "encryption-password-salt-0123456", + }) require.NoError(t, err) require.NoError(t, dbcrypt.Register(db, cryptFirst)) From f57659df7d10377b5816ab343de13b2e039d3669 Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Mon, 24 Nov 2025 14:15:40 +0100 Subject: [PATCH 07/13] change: imporve error messages in dbcrypt package --- pkg/dbcrypt/crypto.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/dbcrypt/crypto.go b/pkg/dbcrypt/crypto.go index 3ef45dc..bc135b2 100644 --- a/pkg/dbcrypt/crypto.go +++ b/pkg/dbcrypt/crypto.go @@ -50,7 +50,7 @@ func encryptModel(c *DBCipher, plaintext any) error { case reflect.Struct, reflect.Slice, reflect.Map: return encryptRecursive(c, value) default: - return errors.New("invalid value provided for encryption") + return fmt.Errorf("invalid %s value provided for encryption", value.Kind().String()) } } @@ -113,7 +113,7 @@ func decryptModel(c *DBCipher, ciphertext any) error { case reflect.Struct, reflect.Slice, reflect.Map: return decryptRecursive(c, value) default: - return errors.New("invalid value provided for decryption") + return fmt.Errorf("invalid %s value provided for decryption", value.Kind().String()) } } From 134efa80f4ebd3518e364ac7f82b5f65e09f3913 Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Mon, 24 Nov 2025 14:34:22 +0100 Subject: [PATCH 08/13] change: imporve database model handling in dbcrypt package --- pkg/dbcrypt/crypto.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/dbcrypt/crypto.go b/pkg/dbcrypt/crypto.go index bc135b2..8eb3ff2 100644 --- a/pkg/dbcrypt/crypto.go +++ b/pkg/dbcrypt/crypto.go @@ -50,7 +50,7 @@ func encryptModel(c *DBCipher, plaintext any) error { case reflect.Struct, reflect.Slice, reflect.Map: return encryptRecursive(c, value) default: - return fmt.Errorf("invalid %s value provided for encryption", value.Kind().String()) + return nil } } @@ -113,7 +113,7 @@ func decryptModel(c *DBCipher, ciphertext any) error { case reflect.Struct, reflect.Slice, reflect.Map: return decryptRecursive(c, value) default: - return fmt.Errorf("invalid %s value provided for decryption", value.Kind().String()) + return nil } } From 25bc9391929ce34be567cc17b5ff9e219c088d77 Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Mon, 24 Nov 2025 15:12:07 +0100 Subject: [PATCH 09/13] change: imporve database model handling in dbcrypt package --- pkg/dbcrypt/dbcipher.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/dbcrypt/dbcipher.go b/pkg/dbcrypt/dbcipher.go index a98c6e9..900a2cf 100644 --- a/pkg/dbcrypt/dbcipher.go +++ b/pkg/dbcrypt/dbcipher.go @@ -103,6 +103,9 @@ func (c *DBCipher) Encrypt(plaintext []byte) ([]byte, error) { // Decrypt decrypts the provided bytes with DBCipher. func (c *DBCipher) Decrypt(ciphertextWithPrefix []byte) ([]byte, error) { + if len(ciphertextWithPrefix) == 0 { + return nil, nil + } prefix, ciphertext, hasSeparator := bytes.Cut(ciphertextWithPrefix, []byte(prefixSeparator)) if !hasSeparator { return nil, errors.New("invalid encrypted value format") From c3b8d26c0991b83bb850e7b6c6563195620a753d Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Tue, 25 Nov 2025 13:47:43 +0100 Subject: [PATCH 10/13] change: refine new implementation of dbcrypt package --- pkg/dbcrypt/README.md | 2 +- pkg/dbcrypt/cipher.go | 121 +++++++----------- pkg/dbcrypt/cipher_spec.go | 106 +++++++++++++++ pkg/dbcrypt/dbcipher.go | 66 ++++------ .../{cipher_test.go => dbcipher_test.go} | 0 pkg/dbcrypt/{crypto.go => gorm.go} | 4 +- 6 files changed, 186 insertions(+), 113 deletions(-) create mode 100644 pkg/dbcrypt/cipher_spec.go rename pkg/dbcrypt/{cipher_test.go => dbcipher_test.go} (100%) rename pkg/dbcrypt/{crypto.go => gorm.go} (94%) diff --git a/pkg/dbcrypt/README.md b/pkg/dbcrypt/README.md index 11d51dd..c5af58e 100644 --- a/pkg/dbcrypt/README.md +++ b/pkg/dbcrypt/README.md @@ -37,7 +37,7 @@ func main() { } dbcrypt.Register(db, cipher) - personWrite := &Person{PasswordField: dbcrypt.NewEncryptedString("secret")} + personWrite := &Person{PasswordField: "secret"} if err := db.Create(personWrite).Error; err != nil { log.Fatalf("Error %v", err) } diff --git a/pkg/dbcrypt/cipher.go b/pkg/dbcrypt/cipher.go index bf69dff..9783918 100644 --- a/pkg/dbcrypt/cipher.go +++ b/pkg/dbcrypt/cipher.go @@ -7,81 +7,66 @@ package dbcrypt import ( "crypto/aes" "crypto/cipher" - "crypto/rand" "encoding/base64" "encoding/hex" "fmt" - "io" "golang.org/x/crypto/argon2" ) type dbCipher interface { - Prefix() string Encrypt(plaintext []byte) ([]byte, error) Decrypt(ciphertext []byte) ([]byte, error) } -type dbCipherV1 struct { +type dbCipherGcmAes struct { key []byte } -func newDbCipherV1(conf Config) (dbCipher, error) { +func newDbCipherGcmAes(key []byte) dbCipher { + return dbCipherGcmAes{key: key} +} + +func newDbCipherGcmAesWithoutKdf(password, passwordSalt string) dbCipher { // Historically "v1" uses key truncation to 32 bytes. It needs to be preserved for backward compatibility. - key := []byte(conf.Password + conf.PasswordSalt)[:32] - return dbCipherV1{key: key}, nil + key := make([]byte, 32) + copy(key, []byte(password+passwordSalt)) + return newDbCipherGcmAes(key) } -func (c dbCipherV1) Prefix() string { - return "ENC" +func newDbCipherGcmAesWithArgon2idKdf(password, passwordSalt string) dbCipher { + // "v2" uses proper KDF (argon2id) to get the key. + key := argon2.IDKey([]byte(password), []byte(passwordSalt), 1, 64*1024, 4, 32) + return newDbCipherGcmAes(key) } -func (c dbCipherV1) Encrypt(plaintext []byte) ([]byte, error) { +func (c dbCipherGcmAes) Encrypt(plaintext []byte) ([]byte, error) { block, err := aes.NewCipher(c.key) if err != nil { - return nil, err + return nil, fmt.Errorf("error creating AES cipher: %w", err) } - gcm, err := cipher.NewGCM(block) + gcm, err := cipher.NewGCMWithRandomNonce(block) if err != nil { - return nil, err - } - - iv := make([]byte, gcm.NonceSize()) - if _, err := io.ReadFull(rand.Reader, iv); err != nil { - return nil, err + return nil, fmt.Errorf("error encrypting plaintext: %w", err) } - ciphertext := gcm.Seal(nil, iv, []byte(plaintext), nil) - ciphertextWithIv := append(iv, ciphertext...) - encoded := hex.AppendEncode(nil, ciphertextWithIv) - return encoded, nil + ciphertext := gcm.Seal(nil, nil, []byte(plaintext), nil) + return ciphertext, nil } -func (c dbCipherV1) Decrypt(encoded []byte) ([]byte, error) { - ciphertextWithIv, err := hex.AppendDecode(nil, encoded) - if err != nil { - return nil, fmt.Errorf("error decoding ciphertext: %w", err) - } - - if len(ciphertextWithIv) < aes.BlockSize+1 { - return nil, fmt.Errorf("ciphertext too short") - } - +func (c dbCipherGcmAes) Decrypt(ciphertext []byte) ([]byte, error) { block, err := aes.NewCipher(c.key) if err != nil { return nil, fmt.Errorf("error creating AES cipher: %w", err) } - gcm, err := cipher.NewGCM(block) + gcm, err := cipher.NewGCMWithRandomNonce(block) if err != nil { - return nil, err + return nil, fmt.Errorf("error decrypting ciphertext: %w", err) } - iv := ciphertextWithIv[:gcm.NonceSize()] - ciphertext := ciphertextWithIv[gcm.NonceSize():] - - plaintext, err := gcm.Open(nil, iv, ciphertext, nil) + plaintext, err := gcm.Open(nil, nil, ciphertext, nil) if err != nil { return nil, fmt.Errorf("error decrypting ciphertext: %w", err) } @@ -89,60 +74,52 @@ func (c dbCipherV1) Decrypt(encoded []byte) ([]byte, error) { return plaintext, nil } -type dbCipherV2 struct { - key []byte -} - -func newDbCipherV2(conf Config) (dbCipher, error) { - // "v2" uses proper KDF (argon2id) to get the key - key := argon2.IDKey([]byte(conf.Password), []byte(conf.PasswordSalt), 1, 64*1024, 4, 32) - return dbCipherV2{key: key}, nil +type dbCipherHexEncode struct { + impl dbCipher } -func (c dbCipherV2) Prefix() string { - return "ENCV2" +func newDbCipherHexEncode(impl dbCipher) dbCipher { + return dbCipherHexEncode{impl: impl} } -func (c dbCipherV2) Encrypt(plaintext []byte) ([]byte, error) { - block, err := aes.NewCipher(c.key) +func (c dbCipherHexEncode) Encrypt(plaintext []byte) ([]byte, error) { + ciphertext, err := c.impl.Encrypt(plaintext) if err != nil { return nil, err } - - gcm, err := cipher.NewGCMWithRandomNonce(block) - if err != nil { - return nil, err - } - - ciphertext := gcm.Seal(nil, nil, []byte(plaintext), nil) - encoded := base64.StdEncoding.AppendEncode(nil, ciphertext) + encoded := hex.AppendEncode(nil, ciphertext) return encoded, nil } -func (c dbCipherV2) Decrypt(encoded []byte) ([]byte, error) { - ciphertext, err := base64.StdEncoding.AppendDecode(nil, encoded) +func (c dbCipherHexEncode) Decrypt(encoded []byte) ([]byte, error) { + ciphertext, err := hex.AppendDecode(nil, encoded) if err != nil { return nil, fmt.Errorf("error decoding ciphertext: %w", err) } + return c.impl.Decrypt(ciphertext) +} - if len(ciphertext) < aes.BlockSize+1 { - return nil, fmt.Errorf("ciphertext too short") - } +type dbCipherBase64Encode struct { + impl dbCipher +} - block, err := aes.NewCipher(c.key) - if err != nil { - return nil, fmt.Errorf("error creating AES cipher: %w", err) - } +func newDbCipherBase64Encode(impl dbCipher) dbCipher { + return dbCipherBase64Encode{impl: impl} +} - gcm, err := cipher.NewGCMWithRandomNonce(block) +func (c dbCipherBase64Encode) Encrypt(plaintext []byte) ([]byte, error) { + ciphertext, err := c.impl.Encrypt(plaintext) if err != nil { return nil, err } + encoded := base64.StdEncoding.AppendEncode(nil, ciphertext) + return encoded, nil +} - plaintext, err := gcm.Open(nil, nil, ciphertext, nil) +func (c dbCipherBase64Encode) Decrypt(encoded []byte) ([]byte, error) { + ciphertext, err := base64.StdEncoding.AppendDecode(nil, encoded) if err != nil { - return nil, fmt.Errorf("error decrypting ciphertext: %w", err) + return nil, fmt.Errorf("error decoding ciphertext: %w", err) } - - return plaintext, nil + return c.impl.Decrypt(ciphertext) } diff --git a/pkg/dbcrypt/cipher_spec.go b/pkg/dbcrypt/cipher_spec.go new file mode 100644 index 0000000..a506057 --- /dev/null +++ b/pkg/dbcrypt/cipher_spec.go @@ -0,0 +1,106 @@ +package dbcrypt + +import ( + "fmt" + "strings" +) + +type cipherSpec struct { + Version string + Prefix string + Cipher dbCipher +} + +func (cs *cipherSpec) Validate() error { + if cs.Version == "" { + return fmt.Errorf("version is missing") + } + if cs.Prefix == "" { + return fmt.Errorf("prefix is missing") + } + if strings.Contains(cs.Prefix, prefixSeparator) { + return fmt.Errorf("prefix cannot contain %q", prefixSeparator) + } + if cs.Cipher == nil { + return fmt.Errorf("cipher is missing") + } + return nil +} + +type ciphersSpec struct { + DefaultVersion string + Ciphers []cipherSpec +} + +func (cs *ciphersSpec) Validate() error { + if cs.DefaultVersion == "" { + return fmt.Errorf("default version is missing") + } + + seenVersions := make(map[string]bool) + seenPrefix := make(map[string]bool) + defaultFound := false + for _, spec := range cs.Ciphers { + if err := spec.Validate(); err != nil { + return fmt.Errorf("cipher spec: %w", err) + } + if seenVersions[spec.Version] { + return fmt.Errorf("duplicate cipher spec version %q", spec.Version) + } + seenVersions[spec.Version] = true + + if seenPrefix[spec.Prefix] { + return fmt.Errorf("duplicate cipher spec prefix %q", spec.Prefix) + } + seenPrefix[spec.Prefix] = true + + if spec.Version == cs.DefaultVersion { + defaultFound = true + } + } + if !defaultFound { + return fmt.Errorf("default version %q not found in cipher specs", cs.DefaultVersion) + } + + return nil +} + +func newCiphersSpec(conf Config) (*ciphersSpec, error) { + cs := &ciphersSpec{ + DefaultVersion: "v2", + Ciphers: []cipherSpec{ // /!\ this list can only be extended, otherwise decryption will break for existing data + { + Version: "v1", + Prefix: "ENC", + Cipher: newDbCipherHexEncode(newDbCipherGcmAesWithoutKdf(conf.Password, conf.PasswordSalt)), + }, + { + Version: "v2", + Prefix: "ENCV2", + Cipher: newDbCipherBase64Encode(newDbCipherGcmAesWithArgon2idKdf(conf.Password, conf.PasswordSalt)), + }, + }, + } + if err := cs.Validate(); err != nil { + return nil, err + } + return cs, nil +} + +func (cs *ciphersSpec) GetByVersion(version string) (*cipherSpec, error) { + for _, spec := range cs.Ciphers { + if spec.Version == version { + return &spec, nil + } + } + return nil, fmt.Errorf("cipher version %q not found", version) +} + +func (cs *ciphersSpec) GetByPrefix(prefix string) (*cipherSpec, error) { + for _, spec := range cs.Ciphers { + if spec.Prefix == prefix { + return &spec, nil + } + } + return nil, fmt.Errorf("cipher prefix %q not found", prefix) +} diff --git a/pkg/dbcrypt/dbcipher.go b/pkg/dbcrypt/dbcipher.go index 900a2cf..130f36b 100644 --- a/pkg/dbcrypt/dbcipher.go +++ b/pkg/dbcrypt/dbcipher.go @@ -16,12 +16,15 @@ const prefixSeparator = ":" type Config struct { // Default version of the cryptographic algorithm. Useful for testing older historical implementations. Leave empty to use the most recent version. // - // - use "" for latest version of the cryptographic algorithm - // - use "v2" for v2 version of the cryptographic algorithm - // - use "v1" for v1 version of the cryptographic algorithm + // Supported values: + // - "": use latest version of the cryptographic algorithm (recommended). + // - "v2": use v2 version of the cryptographic algorithm. + // - "v1": use v1 version of the cryptographic algorithm. + // + // See cipher_spec.go for all versions Version string - // Contains the password used deriving encryption key + // Contains the password used to derive encryption key Password string // Contains the salt for increasing password entropy @@ -30,9 +33,6 @@ type Config struct { // Validate validates the provided config. func (conf Config) Validate() error { - if conf.Version != "" && conf.Version != "v1" && conf.Version != "v2" { - return fmt.Errorf("invalid db cipher version %q", conf.Version) - } if conf.Password == "" { return errors.New("db password is empty") } @@ -47,8 +47,8 @@ func (conf Config) Validate() error { // DBCipher is cipher designed to perform validated encryption and decryption on database values. type DBCipher struct { - encryptionCipher dbCipher - decryptionCiphers map[string]dbCipher + encryptionCipherSpec *cipherSpec + ciphersSpec *ciphersSpec } // NewDBCipher creates a new instance of DBCipher based on the provided Config. @@ -56,46 +56,36 @@ func NewDBCipher(conf Config) (*DBCipher, error) { if err := conf.Validate(); err != nil { return nil, err } - c := &DBCipher{} - if err := c.registerCiphers(conf); err != nil { - return nil, err + spec, err := newCiphersSpec(conf) + if err != nil { + return nil, fmt.Errorf("error creating crypto ciphers spec: %w", err) } - return c, nil -} -func (c *DBCipher) registerCiphers(conf Config) error { - v1, err := newDbCipherV1(conf) - if err != nil { - return err + encryptionVersion := conf.Version + if encryptionVersion == "" { + encryptionVersion = spec.DefaultVersion } - v2, err := newDbCipherV2(conf) + + encryptionCipherSpec, err := spec.GetByVersion(encryptionVersion) if err != nil { - return err + return nil, fmt.Errorf("could not get encryption cipher by version: %w", err) } - c.decryptionCiphers = map[string]dbCipher{ - v1.Prefix(): v1, - v2.Prefix(): v2, - } - switch conf.Version { - case "v1": - c.encryptionCipher = v1 - case "v2", "": - c.encryptionCipher = v2 - default: - panic("invalid db cipher version") // valid config should never reach this code + c := &DBCipher{ + encryptionCipherSpec: encryptionCipherSpec, + ciphersSpec: spec, } - return nil + return c, nil } // Encrypt encrypts the provided bytes with DBCipher. func (c *DBCipher) Encrypt(plaintext []byte) ([]byte, error) { - ciphertext, err := c.encryptionCipher.Encrypt(plaintext) + ciphertext, err := c.encryptionCipherSpec.Cipher.Encrypt(plaintext) if err != nil { return nil, err } ciphertextWithPrefix := bytes.NewBuffer(nil) - ciphertextWithPrefix.WriteString(c.encryptionCipher.Prefix()) + ciphertextWithPrefix.WriteString(c.encryptionCipherSpec.Prefix) ciphertextWithPrefix.WriteString(prefixSeparator) ciphertextWithPrefix.Write(ciphertext) return ciphertextWithPrefix.Bytes(), nil @@ -110,11 +100,11 @@ func (c *DBCipher) Decrypt(ciphertextWithPrefix []byte) ([]byte, error) { if !hasSeparator { return nil, errors.New("invalid encrypted value format") } - cipher := c.decryptionCiphers[string(prefix)] - if cipher == nil { - return nil, errors.New("unknown encrypted value format") + decryptionCipherSpec, err := c.ciphersSpec.GetByPrefix(string(prefix)) + if err != nil { + return nil, fmt.Errorf("unknown encrypted value format: %w", err) } - plaintext, err := cipher.Decrypt(ciphertext) + plaintext, err := decryptionCipherSpec.Cipher.Decrypt(ciphertext) if err != nil { return nil, err } diff --git a/pkg/dbcrypt/cipher_test.go b/pkg/dbcrypt/dbcipher_test.go similarity index 100% rename from pkg/dbcrypt/cipher_test.go rename to pkg/dbcrypt/dbcipher_test.go diff --git a/pkg/dbcrypt/crypto.go b/pkg/dbcrypt/gorm.go similarity index 94% rename from pkg/dbcrypt/crypto.go rename to pkg/dbcrypt/gorm.go index 8eb3ff2..e812f0f 100644 --- a/pkg/dbcrypt/crypto.go +++ b/pkg/dbcrypt/gorm.go @@ -173,10 +173,10 @@ func decryptFieldBasedOnTag(c *DBCipher, sf reflect.StructField, val reflect.Val // Register registers encryption and decryption callbacks for the provided data base, to perform automatically cryptographic operations on all models that contain a field tagged with 'encrypt:"true"'. func Register(db *gorm.DB, c *DBCipher) error { encryptCb := func(db *gorm.DB) { - db.AddError(encryptModel(c, &db.Statement.Dest)) //nolint:errcheck // error value returned by AddError can be safely ignored, as it is the same error as db.Error. + _ = db.AddError(encryptModel(c, &db.Statement.Dest)) } decryptCb := func(db *gorm.DB) { - db.AddError(decryptModel(c, &db.Statement.Dest)) //nolint:errcheck // error value returned by AddError can be safely ignored, as it is the same error as db.Error. + _ = db.AddError(decryptModel(c, &db.Statement.Dest)) } if err := db.Callback(). From b1dc4285ac7fd7b0160b657cb9a26df66df0ee36 Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Tue, 25 Nov 2025 13:52:04 +0100 Subject: [PATCH 11/13] change: fix failing tests in dbcrypt package --- pkg/dbcrypt/dbcipher_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/dbcrypt/dbcipher_test.go b/pkg/dbcrypt/dbcipher_test.go index ea4afc4..096d53b 100644 --- a/pkg/dbcrypt/dbcipher_test.go +++ b/pkg/dbcrypt/dbcipher_test.go @@ -113,7 +113,7 @@ func TestCipherCreationFailure(t *testing.T) { Password: "encryption-password", PasswordSalt: "encryption-password-salt-0123456", }, - errorShouldContain: "invalid db cipher version", + errorShouldContain: "cipher version \"unknown\" not found", }, { name: "empty-password", From c89312c6994984381dbe213d27d19ea1e639450a Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Tue, 25 Nov 2025 13:58:01 +0100 Subject: [PATCH 12/13] change: add licence headers to dbcrypt package --- pkg/dbcrypt/cipher_spec.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/dbcrypt/cipher_spec.go b/pkg/dbcrypt/cipher_spec.go index a506057..f822e03 100644 --- a/pkg/dbcrypt/cipher_spec.go +++ b/pkg/dbcrypt/cipher_spec.go @@ -1,3 +1,7 @@ +// SPDX-FileCopyrightText: 2025 Greenbone AG +// +// SPDX-License-Identifier: AGPL-3.0-or-later + package dbcrypt import ( From 7a9b6b13a95e0d24e2b1df0ccb9261eab1bffd95 Mon Sep 17 00:00:00 2001 From: Marek Dalewski Date: Tue, 25 Nov 2025 14:18:42 +0100 Subject: [PATCH 13/13] change: remove unneded comments --- pkg/dbcrypt/cipher.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/dbcrypt/cipher.go b/pkg/dbcrypt/cipher.go index 9783918..4f6f4d2 100644 --- a/pkg/dbcrypt/cipher.go +++ b/pkg/dbcrypt/cipher.go @@ -35,7 +35,6 @@ func newDbCipherGcmAesWithoutKdf(password, passwordSalt string) dbCipher { } func newDbCipherGcmAesWithArgon2idKdf(password, passwordSalt string) dbCipher { - // "v2" uses proper KDF (argon2id) to get the key. key := argon2.IDKey([]byte(password), []byte(passwordSalt), 1, 64*1024, 4, 32) return newDbCipherGcmAes(key) }