Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/dbcrypt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func main() {
}
dbcrypt.Register(db, cipher)

personWrite := &Person{PasswordField: dbcrypt.NewEncryptedString("secret")}
personWrite := &Person{PasswordField: "secret"}
if err := db.Create(personWrite).Error; err != nil {
log.Fatalf("Error %v", err)
}
Expand Down
121 changes: 49 additions & 72 deletions pkg/dbcrypt/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,142 +7,119 @@ package dbcrypt
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"

"golang.org/x/crypto/argon2"
)

type dbCipher interface {
Prefix() string
Encrypt(plaintext []byte) ([]byte, error)
Decrypt(ciphertext []byte) ([]byte, error)
}

type dbCipherV1 struct {
type dbCipherGcmAes struct {
key []byte
}

func newDbCipherV1(conf Config) (dbCipher, error) {
func newDbCipherGcmAes(key []byte) dbCipher {
return dbCipherGcmAes{key: key}
}

func newDbCipherGcmAesWithoutKdf(password, passwordSalt string) dbCipher {
// Historically "v1" uses key truncation to 32 bytes. It needs to be preserved for backward compatibility.
key := []byte(conf.Password + conf.PasswordSalt)[:32]
return dbCipherV1{key: key}, nil
key := make([]byte, 32)
copy(key, []byte(password+passwordSalt))
return newDbCipherGcmAes(key)
}

func (c dbCipherV1) Prefix() string {
return "ENC"
func newDbCipherGcmAesWithArgon2idKdf(password, passwordSalt string) dbCipher {
// "v2" uses proper KDF (argon2id) to get the key.
Comment thread
badarghal marked this conversation as resolved.
Outdated
key := argon2.IDKey([]byte(password), []byte(passwordSalt), 1, 64*1024, 4, 32)
return newDbCipherGcmAes(key)
}

func (c dbCipherV1) Encrypt(plaintext []byte) ([]byte, error) {
func (c dbCipherGcmAes) Encrypt(plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(c.key)
if err != nil {
return nil, err
return nil, fmt.Errorf("error creating AES cipher: %w", err)
}

gcm, err := cipher.NewGCM(block)
gcm, err := cipher.NewGCMWithRandomNonce(block)
if err != nil {
return nil, err
}

iv := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
return nil, fmt.Errorf("error encrypting plaintext: %w", err)
}

ciphertext := gcm.Seal(nil, iv, []byte(plaintext), nil)
ciphertextWithIv := append(iv, ciphertext...)
encoded := hex.AppendEncode(nil, ciphertextWithIv)
return encoded, nil
ciphertext := gcm.Seal(nil, nil, []byte(plaintext), nil)
return ciphertext, nil
}

func (c dbCipherV1) Decrypt(encoded []byte) ([]byte, error) {
ciphertextWithIv, err := hex.AppendDecode(nil, encoded)
if err != nil {
return nil, fmt.Errorf("error decoding ciphertext: %w", err)
}

if len(ciphertextWithIv) < aes.BlockSize+1 {
return nil, fmt.Errorf("ciphertext too short")
}

func (c dbCipherGcmAes) Decrypt(ciphertext []byte) ([]byte, error) {
block, err := aes.NewCipher(c.key)
if err != nil {
return nil, fmt.Errorf("error creating AES cipher: %w", err)
}

gcm, err := cipher.NewGCM(block)
gcm, err := cipher.NewGCMWithRandomNonce(block)
if err != nil {
return nil, err
return nil, fmt.Errorf("error decrypting ciphertext: %w", err)
}

iv := ciphertextWithIv[:gcm.NonceSize()]
ciphertext := ciphertextWithIv[gcm.NonceSize():]

plaintext, err := gcm.Open(nil, iv, ciphertext, nil)
plaintext, err := gcm.Open(nil, nil, ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("error decrypting ciphertext: %w", err)
}

return plaintext, nil
}

type dbCipherV2 struct {
key []byte
}

func newDbCipherV2(conf Config) (dbCipher, error) {
// "v2" uses proper KDF (argon2id) to get the key
key := argon2.IDKey([]byte(conf.Password), []byte(conf.PasswordSalt), 1, 64*1024, 4, 32)
return dbCipherV2{key: key}, nil
type dbCipherHexEncode struct {
impl dbCipher
}

func (c dbCipherV2) Prefix() string {
return "ENCV2"
func newDbCipherHexEncode(impl dbCipher) dbCipher {
return dbCipherHexEncode{impl: impl}
}

func (c dbCipherV2) Encrypt(plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(c.key)
func (c dbCipherHexEncode) Encrypt(plaintext []byte) ([]byte, error) {
ciphertext, err := c.impl.Encrypt(plaintext)
if err != nil {
return nil, err
}

gcm, err := cipher.NewGCMWithRandomNonce(block)
if err != nil {
return nil, err
}

ciphertext := gcm.Seal(nil, nil, []byte(plaintext), nil)
encoded := base64.StdEncoding.AppendEncode(nil, ciphertext)
encoded := hex.AppendEncode(nil, ciphertext)
return encoded, nil
}

func (c dbCipherV2) Decrypt(encoded []byte) ([]byte, error) {
ciphertext, err := base64.StdEncoding.AppendDecode(nil, encoded)
func (c dbCipherHexEncode) Decrypt(encoded []byte) ([]byte, error) {
ciphertext, err := hex.AppendDecode(nil, encoded)
if err != nil {
return nil, fmt.Errorf("error decoding ciphertext: %w", err)
}
return c.impl.Decrypt(ciphertext)
}

if len(ciphertext) < aes.BlockSize+1 {
return nil, fmt.Errorf("ciphertext too short")
}
type dbCipherBase64Encode struct {
impl dbCipher
}

block, err := aes.NewCipher(c.key)
if err != nil {
return nil, fmt.Errorf("error creating AES cipher: %w", err)
}
func newDbCipherBase64Encode(impl dbCipher) dbCipher {
return dbCipherBase64Encode{impl: impl}
}

gcm, err := cipher.NewGCMWithRandomNonce(block)
func (c dbCipherBase64Encode) Encrypt(plaintext []byte) ([]byte, error) {
ciphertext, err := c.impl.Encrypt(plaintext)
if err != nil {
return nil, err
}
encoded := base64.StdEncoding.AppendEncode(nil, ciphertext)
return encoded, nil
}

plaintext, err := gcm.Open(nil, nil, ciphertext, nil)
func (c dbCipherBase64Encode) Decrypt(encoded []byte) ([]byte, error) {
ciphertext, err := base64.StdEncoding.AppendDecode(nil, encoded)
if err != nil {
return nil, fmt.Errorf("error decrypting ciphertext: %w", err)
return nil, fmt.Errorf("error decoding ciphertext: %w", err)
}

return plaintext, nil
return c.impl.Decrypt(ciphertext)
}
106 changes: 106 additions & 0 deletions pkg/dbcrypt/cipher_spec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package dbcrypt

import (
"fmt"
"strings"
)

type cipherSpec struct {
Version string
Prefix string
Cipher dbCipher
}

func (cs *cipherSpec) Validate() error {
if cs.Version == "" {
return fmt.Errorf("version is missing")
}
if cs.Prefix == "" {
return fmt.Errorf("prefix is missing")
}
if strings.Contains(cs.Prefix, prefixSeparator) {
return fmt.Errorf("prefix cannot contain %q", prefixSeparator)
}
if cs.Cipher == nil {
return fmt.Errorf("cipher is missing")
}
return nil
}

type ciphersSpec struct {
DefaultVersion string
Ciphers []cipherSpec
}

func (cs *ciphersSpec) Validate() error {
if cs.DefaultVersion == "" {
return fmt.Errorf("default version is missing")
}

seenVersions := make(map[string]bool)
seenPrefix := make(map[string]bool)
defaultFound := false
for _, spec := range cs.Ciphers {
if err := spec.Validate(); err != nil {
return fmt.Errorf("cipher spec: %w", err)
}
if seenVersions[spec.Version] {
return fmt.Errorf("duplicate cipher spec version %q", spec.Version)
}
seenVersions[spec.Version] = true

if seenPrefix[spec.Prefix] {
return fmt.Errorf("duplicate cipher spec prefix %q", spec.Prefix)
}
seenPrefix[spec.Prefix] = true

if spec.Version == cs.DefaultVersion {
defaultFound = true
}
}
if !defaultFound {
return fmt.Errorf("default version %q not found in cipher specs", cs.DefaultVersion)
}

return nil
}

func newCiphersSpec(conf Config) (*ciphersSpec, error) {
cs := &ciphersSpec{
DefaultVersion: "v2",
Ciphers: []cipherSpec{ // /!\ this list can only be extended, otherwise decryption will break for existing data
{
Version: "v1",
Prefix: "ENC",
Cipher: newDbCipherHexEncode(newDbCipherGcmAesWithoutKdf(conf.Password, conf.PasswordSalt)),
},
{
Version: "v2",
Prefix: "ENCV2",
Cipher: newDbCipherBase64Encode(newDbCipherGcmAesWithArgon2idKdf(conf.Password, conf.PasswordSalt)),
},
},
}
if err := cs.Validate(); err != nil {
return nil, err
}
return cs, nil
}

func (cs *ciphersSpec) GetByVersion(version string) (*cipherSpec, error) {
for _, spec := range cs.Ciphers {
if spec.Version == version {
return &spec, nil
}
}
return nil, fmt.Errorf("cipher version %q not found", version)
}

func (cs *ciphersSpec) GetByPrefix(prefix string) (*cipherSpec, error) {
for _, spec := range cs.Ciphers {
if spec.Prefix == prefix {
return &spec, nil
}
}
return nil, fmt.Errorf("cipher prefix %q not found", prefix)
}
Loading
Loading