Skip to content

Commit c3b8d26

Browse files
author
Marek Dalewski
committed
change: refine new implementation of dbcrypt package
1 parent febb8d6 commit c3b8d26

6 files changed

Lines changed: 186 additions & 113 deletions

File tree

pkg/dbcrypt/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func main() {
3737
}
3838
dbcrypt.Register(db, cipher)
3939

40-
personWrite := &Person{PasswordField: dbcrypt.NewEncryptedString("secret")}
40+
personWrite := &Person{PasswordField: "secret"}
4141
if err := db.Create(personWrite).Error; err != nil {
4242
log.Fatalf("Error %v", err)
4343
}

pkg/dbcrypt/cipher.go

Lines changed: 49 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -7,142 +7,119 @@ package dbcrypt
77
import (
88
"crypto/aes"
99
"crypto/cipher"
10-
"crypto/rand"
1110
"encoding/base64"
1211
"encoding/hex"
1312
"fmt"
14-
"io"
1513

1614
"golang.org/x/crypto/argon2"
1715
)
1816

1917
type dbCipher interface {
20-
Prefix() string
2118
Encrypt(plaintext []byte) ([]byte, error)
2219
Decrypt(ciphertext []byte) ([]byte, error)
2320
}
2421

25-
type dbCipherV1 struct {
22+
type dbCipherGcmAes struct {
2623
key []byte
2724
}
2825

29-
func newDbCipherV1(conf Config) (dbCipher, error) {
26+
func newDbCipherGcmAes(key []byte) dbCipher {
27+
return dbCipherGcmAes{key: key}
28+
}
29+
30+
func newDbCipherGcmAesWithoutKdf(password, passwordSalt string) dbCipher {
3031
// Historically "v1" uses key truncation to 32 bytes. It needs to be preserved for backward compatibility.
31-
key := []byte(conf.Password + conf.PasswordSalt)[:32]
32-
return dbCipherV1{key: key}, nil
32+
key := make([]byte, 32)
33+
copy(key, []byte(password+passwordSalt))
34+
return newDbCipherGcmAes(key)
3335
}
3436

35-
func (c dbCipherV1) Prefix() string {
36-
return "ENC"
37+
func newDbCipherGcmAesWithArgon2idKdf(password, passwordSalt string) dbCipher {
38+
// "v2" uses proper KDF (argon2id) to get the key.
39+
key := argon2.IDKey([]byte(password), []byte(passwordSalt), 1, 64*1024, 4, 32)
40+
return newDbCipherGcmAes(key)
3741
}
3842

39-
func (c dbCipherV1) Encrypt(plaintext []byte) ([]byte, error) {
43+
func (c dbCipherGcmAes) Encrypt(plaintext []byte) ([]byte, error) {
4044
block, err := aes.NewCipher(c.key)
4145
if err != nil {
42-
return nil, err
46+
return nil, fmt.Errorf("error creating AES cipher: %w", err)
4347
}
4448

45-
gcm, err := cipher.NewGCM(block)
49+
gcm, err := cipher.NewGCMWithRandomNonce(block)
4650
if err != nil {
47-
return nil, err
48-
}
49-
50-
iv := make([]byte, gcm.NonceSize())
51-
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
52-
return nil, err
51+
return nil, fmt.Errorf("error encrypting plaintext: %w", err)
5352
}
5453

55-
ciphertext := gcm.Seal(nil, iv, []byte(plaintext), nil)
56-
ciphertextWithIv := append(iv, ciphertext...)
57-
encoded := hex.AppendEncode(nil, ciphertextWithIv)
58-
return encoded, nil
54+
ciphertext := gcm.Seal(nil, nil, []byte(plaintext), nil)
55+
return ciphertext, nil
5956
}
6057

61-
func (c dbCipherV1) Decrypt(encoded []byte) ([]byte, error) {
62-
ciphertextWithIv, err := hex.AppendDecode(nil, encoded)
63-
if err != nil {
64-
return nil, fmt.Errorf("error decoding ciphertext: %w", err)
65-
}
66-
67-
if len(ciphertextWithIv) < aes.BlockSize+1 {
68-
return nil, fmt.Errorf("ciphertext too short")
69-
}
70-
58+
func (c dbCipherGcmAes) Decrypt(ciphertext []byte) ([]byte, error) {
7159
block, err := aes.NewCipher(c.key)
7260
if err != nil {
7361
return nil, fmt.Errorf("error creating AES cipher: %w", err)
7462
}
7563

76-
gcm, err := cipher.NewGCM(block)
64+
gcm, err := cipher.NewGCMWithRandomNonce(block)
7765
if err != nil {
78-
return nil, err
66+
return nil, fmt.Errorf("error decrypting ciphertext: %w", err)
7967
}
8068

81-
iv := ciphertextWithIv[:gcm.NonceSize()]
82-
ciphertext := ciphertextWithIv[gcm.NonceSize():]
83-
84-
plaintext, err := gcm.Open(nil, iv, ciphertext, nil)
69+
plaintext, err := gcm.Open(nil, nil, ciphertext, nil)
8570
if err != nil {
8671
return nil, fmt.Errorf("error decrypting ciphertext: %w", err)
8772
}
8873

8974
return plaintext, nil
9075
}
9176

92-
type dbCipherV2 struct {
93-
key []byte
94-
}
95-
96-
func newDbCipherV2(conf Config) (dbCipher, error) {
97-
// "v2" uses proper KDF (argon2id) to get the key
98-
key := argon2.IDKey([]byte(conf.Password), []byte(conf.PasswordSalt), 1, 64*1024, 4, 32)
99-
return dbCipherV2{key: key}, nil
77+
type dbCipherHexEncode struct {
78+
impl dbCipher
10079
}
10180

102-
func (c dbCipherV2) Prefix() string {
103-
return "ENCV2"
81+
func newDbCipherHexEncode(impl dbCipher) dbCipher {
82+
return dbCipherHexEncode{impl: impl}
10483
}
10584

106-
func (c dbCipherV2) Encrypt(plaintext []byte) ([]byte, error) {
107-
block, err := aes.NewCipher(c.key)
85+
func (c dbCipherHexEncode) Encrypt(plaintext []byte) ([]byte, error) {
86+
ciphertext, err := c.impl.Encrypt(plaintext)
10887
if err != nil {
10988
return nil, err
11089
}
111-
112-
gcm, err := cipher.NewGCMWithRandomNonce(block)
113-
if err != nil {
114-
return nil, err
115-
}
116-
117-
ciphertext := gcm.Seal(nil, nil, []byte(plaintext), nil)
118-
encoded := base64.StdEncoding.AppendEncode(nil, ciphertext)
90+
encoded := hex.AppendEncode(nil, ciphertext)
11991
return encoded, nil
12092
}
12193

122-
func (c dbCipherV2) Decrypt(encoded []byte) ([]byte, error) {
123-
ciphertext, err := base64.StdEncoding.AppendDecode(nil, encoded)
94+
func (c dbCipherHexEncode) Decrypt(encoded []byte) ([]byte, error) {
95+
ciphertext, err := hex.AppendDecode(nil, encoded)
12496
if err != nil {
12597
return nil, fmt.Errorf("error decoding ciphertext: %w", err)
12698
}
99+
return c.impl.Decrypt(ciphertext)
100+
}
127101

128-
if len(ciphertext) < aes.BlockSize+1 {
129-
return nil, fmt.Errorf("ciphertext too short")
130-
}
102+
type dbCipherBase64Encode struct {
103+
impl dbCipher
104+
}
131105

132-
block, err := aes.NewCipher(c.key)
133-
if err != nil {
134-
return nil, fmt.Errorf("error creating AES cipher: %w", err)
135-
}
106+
func newDbCipherBase64Encode(impl dbCipher) dbCipher {
107+
return dbCipherBase64Encode{impl: impl}
108+
}
136109

137-
gcm, err := cipher.NewGCMWithRandomNonce(block)
110+
func (c dbCipherBase64Encode) Encrypt(plaintext []byte) ([]byte, error) {
111+
ciphertext, err := c.impl.Encrypt(plaintext)
138112
if err != nil {
139113
return nil, err
140114
}
115+
encoded := base64.StdEncoding.AppendEncode(nil, ciphertext)
116+
return encoded, nil
117+
}
141118

142-
plaintext, err := gcm.Open(nil, nil, ciphertext, nil)
119+
func (c dbCipherBase64Encode) Decrypt(encoded []byte) ([]byte, error) {
120+
ciphertext, err := base64.StdEncoding.AppendDecode(nil, encoded)
143121
if err != nil {
144-
return nil, fmt.Errorf("error decrypting ciphertext: %w", err)
122+
return nil, fmt.Errorf("error decoding ciphertext: %w", err)
145123
}
146-
147-
return plaintext, nil
124+
return c.impl.Decrypt(ciphertext)
148125
}

pkg/dbcrypt/cipher_spec.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package dbcrypt
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
)
7+
8+
type cipherSpec struct {
9+
Version string
10+
Prefix string
11+
Cipher dbCipher
12+
}
13+
14+
func (cs *cipherSpec) Validate() error {
15+
if cs.Version == "" {
16+
return fmt.Errorf("version is missing")
17+
}
18+
if cs.Prefix == "" {
19+
return fmt.Errorf("prefix is missing")
20+
}
21+
if strings.Contains(cs.Prefix, prefixSeparator) {
22+
return fmt.Errorf("prefix cannot contain %q", prefixSeparator)
23+
}
24+
if cs.Cipher == nil {
25+
return fmt.Errorf("cipher is missing")
26+
}
27+
return nil
28+
}
29+
30+
type ciphersSpec struct {
31+
DefaultVersion string
32+
Ciphers []cipherSpec
33+
}
34+
35+
func (cs *ciphersSpec) Validate() error {
36+
if cs.DefaultVersion == "" {
37+
return fmt.Errorf("default version is missing")
38+
}
39+
40+
seenVersions := make(map[string]bool)
41+
seenPrefix := make(map[string]bool)
42+
defaultFound := false
43+
for _, spec := range cs.Ciphers {
44+
if err := spec.Validate(); err != nil {
45+
return fmt.Errorf("cipher spec: %w", err)
46+
}
47+
if seenVersions[spec.Version] {
48+
return fmt.Errorf("duplicate cipher spec version %q", spec.Version)
49+
}
50+
seenVersions[spec.Version] = true
51+
52+
if seenPrefix[spec.Prefix] {
53+
return fmt.Errorf("duplicate cipher spec prefix %q", spec.Prefix)
54+
}
55+
seenPrefix[spec.Prefix] = true
56+
57+
if spec.Version == cs.DefaultVersion {
58+
defaultFound = true
59+
}
60+
}
61+
if !defaultFound {
62+
return fmt.Errorf("default version %q not found in cipher specs", cs.DefaultVersion)
63+
}
64+
65+
return nil
66+
}
67+
68+
func newCiphersSpec(conf Config) (*ciphersSpec, error) {
69+
cs := &ciphersSpec{
70+
DefaultVersion: "v2",
71+
Ciphers: []cipherSpec{ // /!\ this list can only be extended, otherwise decryption will break for existing data
72+
{
73+
Version: "v1",
74+
Prefix: "ENC",
75+
Cipher: newDbCipherHexEncode(newDbCipherGcmAesWithoutKdf(conf.Password, conf.PasswordSalt)),
76+
},
77+
{
78+
Version: "v2",
79+
Prefix: "ENCV2",
80+
Cipher: newDbCipherBase64Encode(newDbCipherGcmAesWithArgon2idKdf(conf.Password, conf.PasswordSalt)),
81+
},
82+
},
83+
}
84+
if err := cs.Validate(); err != nil {
85+
return nil, err
86+
}
87+
return cs, nil
88+
}
89+
90+
func (cs *ciphersSpec) GetByVersion(version string) (*cipherSpec, error) {
91+
for _, spec := range cs.Ciphers {
92+
if spec.Version == version {
93+
return &spec, nil
94+
}
95+
}
96+
return nil, fmt.Errorf("cipher version %q not found", version)
97+
}
98+
99+
func (cs *ciphersSpec) GetByPrefix(prefix string) (*cipherSpec, error) {
100+
for _, spec := range cs.Ciphers {
101+
if spec.Prefix == prefix {
102+
return &spec, nil
103+
}
104+
}
105+
return nil, fmt.Errorf("cipher prefix %q not found", prefix)
106+
}

0 commit comments

Comments
 (0)