Skip to content

Commit 3b15aca

Browse files
committed
enforce uniqueness of origins on database level
1 parent 8596854 commit 3b15aca

3 files changed

Lines changed: 43 additions & 9 deletions

File tree

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
CREATE TABLE notification_service.origins (
22
"name" TEXT NOT NULL,
3-
"class" TEXT NOT NULL,
4-
"service_id" TEXT NOT NULL
3+
"class" TEXT NOT NULL CONSTRAINT origins_class_unique UNIQUE,
4+
"service_id" TEXT NOT NULL
55
);
66

77
CREATE INDEX idx_origins_service_id ON notification_service.origins(service_id);

pkg/repository/originrepository/originRepository.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,14 @@ import (
1010
"fmt"
1111

1212
"github.com/greenbone/opensight-notification-service/pkg/entities"
13+
"github.com/greenbone/opensight-notification-service/pkg/errs"
1314
"github.com/jmoiron/sqlx"
15+
"github.com/lib/pq"
1416
)
1517

18+
// see https://github.com/lib/pq/blob/3d613208bca2e74f2a20e04126ed30bcb5c4cc27/error.go#L78
19+
const pgErrCodeConflict = "23505"
20+
1621
type OriginRepository struct {
1722
client *sqlx.DB
1823
}
@@ -60,6 +65,12 @@ func (r *OriginRepository) UpsertOrigins(ctx context.Context, serviceID string,
6065
if len(originRows) != 0 {
6166
_, err = tx.NamedExec(createOriginsQuery, originRows)
6267
if err != nil {
68+
var pgErr *pq.Error
69+
if errors.As(err, &pgErr) { // postgres specific error handling
70+
if pgErr.Code == pgErrCodeConflict {
71+
err = &errs.ErrConflict{Message: "duplicate origin class"}
72+
}
73+
}
6374
return fmt.Errorf("could not insert origins: %w", err)
6475
}
6576
}

pkg/repository/originrepository/originRepository_test.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package originrepository
66

77
import (
88
"context"
9+
"errors"
910
"sync"
1011
"testing"
1112

@@ -103,6 +104,23 @@ func Test_UpsertOrigins_ListOrigins(t *testing.T) {
103104
},
104105
wantOrigins: []entities.Origin{},
105106
},
107+
"error on duplicate origin classes": {
108+
inputs: []input{
109+
{
110+
serviceID: "service1",
111+
origins: []entities.Origin{
112+
{Name: "origin1", Class: "classA"},
113+
},
114+
},
115+
{
116+
serviceID: "service2",
117+
origins: []entities.Origin{
118+
{Name: "origin1", Class: "classA"}, // class must be unique across all services
119+
},
120+
},
121+
},
122+
wantErr: true,
123+
},
106124
"error on empty serviceID": {
107125
inputs: []input{
108126
{
@@ -125,16 +143,19 @@ func Test_UpsertOrigins_ListOrigins(t *testing.T) {
125143

126144
ctx := context.Background()
127145

146+
var errs []error
128147
for _, input := range tt.inputs {
129148
err := repo.UpsertOrigins(ctx, input.serviceID, input.origins)
130-
if tt.wantErr {
131-
require.Error(t, err)
132-
} else {
133-
require.NoError(t, err)
134-
}
149+
errs = append(errs, err)
150+
}
151+
err = errors.Join(errs...)
152+
if tt.wantErr {
153+
require.Error(t, err)
154+
return
135155
}
136156

137157
// if all operations were successful, verify final state
158+
require.NoError(t, err)
138159
gotOrigins, err := repo.ListOrigins(ctx)
139160
require.NoError(t, err)
140161
assert.ElementsMatch(t, tt.wantOrigins, gotOrigins) // order so far not guaranteed or relevant
@@ -162,7 +183,9 @@ func Test_UpsertOrigins_Concurrency(t *testing.T) {
162183
go func(val int) {
163184
defer wg.Done()
164185
err := repo.UpsertOrigins(context.Background(), serviceID, origins)
165-
assert.NoError(t, err, "failed at iteration %d", val)
186+
if err != nil {
187+
t.Logf("Failed at iteration %d: %v", val, err)
188+
}
166189
}(i)
167190
}
168191
wg.Wait()
@@ -171,5 +194,5 @@ func Test_UpsertOrigins_Concurrency(t *testing.T) {
171194
err = db.QueryRow("SELECT COUNT(*) FROM notification_service.origins WHERE service_id = $1", serviceID).Scan(&count)
172195
require.NoError(t, err)
173196

174-
require.Equal(t, len(origins), count, "Data was duplicated due to race condition!")
197+
require.Equal(t, 2, count, "Data was duplicated due to race condition!")
175198
}

0 commit comments

Comments (0)