Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .drone.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ local bootstrap = '25.02';
local nginx = '1.24.0';
local python = '3.12-slim-bookworm';
local alpine = '3.21';
local visual_diff_skip_build = '2884';
local visual_diff_skip_build = '2914';

local build(arch, testUI) = [{
kind: 'pipeline',
Expand Down
9 changes: 6 additions & 3 deletions backend/auth/authelia.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func NewAuthelia(

func (w *Authelia) RegisterOIDCClient(
id string,
redirectURI string,
redirectURIs []string,
requirePkce bool,
tokenEndpointAuthMethod string,
) (string, error) {
Expand All @@ -137,7 +137,7 @@ func (w *Authelia) RegisterOIDCClient(
err = w.oidc.AddClient(config.OIDCClient{
ID: id,
Secret: secret.Hash,
RedirectURI: redirectURI,
RedirectURIs: redirectURIs,
RequirePkce: requirePkce,
TokenEndpointAuthMethod: tokenEndpointAuthMethod,
})
Expand Down Expand Up @@ -181,7 +181,10 @@ func (w *Authelia) InitConfig() error {
return err
}
for i := range clients {
clients[i].RedirectURI = w.userConfig.Url(clients[i].ID) + clients[i].RedirectURI
appUrl := w.userConfig.Url(clients[i].ID)
for j, redirectURI := range clients[i].RedirectURIs {
clients[i].RedirectURIs[j] = appUrl + redirectURI
}
}
variables := Variables{
Domain: w.userConfig.GetDeviceDomain(),
Expand Down
6 changes: 4 additions & 2 deletions backend/auth/authelia_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ type Client struct {
func TestAutheliaClients(t *testing.T) {
userConfig := &UserConfigStub{domain: "example.com", activated: false}
oidc := &OIDCStub{clients: []config.OIDCClient{
{ID: "app1", Secret: "app1secret", RedirectURI: "/callback1"},
{ID: "app2", Secret: "app2secret", RedirectURI: "/callback2"},
{ID: "app1", Secret: "app1secret", RedirectURIs: []string{"/callback1"}},
{ID: "app2", Secret: "app2secret", RedirectURIs: []string{"/callback2", "/mobile2"}},
}}
outDir := t.TempDir()
secretDir := t.TempDir()
Expand All @@ -162,5 +162,7 @@ func TestAutheliaClients(t *testing.T) {

assert.Equal(t, "app2", gen.IdentityProviders.OIDC.Clients[2].ClientID)
assert.Equal(t, "app2secret", gen.IdentityProviders.OIDC.Clients[2].ClientSecret)
assert.Len(t, gen.IdentityProviders.OIDC.Clients[2].RedirectUris, 2)
assert.Equal(t, "https://app2.example.com/callback2", gen.IdentityProviders.OIDC.Clients[2].RedirectUris[0])
assert.Equal(t, "https://app2.example.com/mobile2", gen.IdentityProviders.OIDC.Clients[2].RedirectUris[1])
}
131 changes: 100 additions & 31 deletions backend/config/migrator.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package config

import (
"context"
"database/sql"
"fmt"
"os"

"github.com/pressly/goose/v3"
)

type Migrator struct {
Expand All @@ -13,53 +16,119 @@ func NewMigrator(db *Db) *Migrator {
return &Migrator{db: db}
}

func (m *Migrator) Migrate() error {
_, err := os.Stat(m.db.File())
if os.IsNotExist(err) {
if err := m.db.Init(); err != nil {
return err
}
func migrations() []*goose.Migration {
return []*goose.Migration{
goose.NewGoMigration(1, &goose.GoFunc{RunTx: createConfigTable}, nil),
goose.NewGoMigration(2, &goose.GoFunc{RunTx: createOidcClientTable}, nil),
goose.NewGoMigration(3, &goose.GoFunc{RunTx: createCustomProxyTable}, nil),
goose.NewGoMigration(4, &goose.GoFunc{RunTx: addCustomProxyHttps}, nil),
goose.NewGoMigration(5, &goose.GoFunc{RunTx: addCustomProxyAuthelia}, nil),
goose.NewGoMigration(6, &goose.GoFunc{RunTx: normalizeOidcRedirectUris}, nil),
}
}

func (m *Migrator) provider() (*goose.Provider, error) {
return goose.NewProvider(goose.DialectSQLite3, m.db.Open(), nil, goose.WithGoMigrations(migrations()...))
}

if err := m.addOidcClientTable(); err != nil {
func (m *Migrator) Migrate() error {
provider, err := m.provider()
if err != nil {
return err
}
if err := m.addCustomProxyTable(); err != nil {
defer provider.Close()
_, err = provider.Up(context.Background())
return err
}

func (m *Migrator) MigrateTo(version int64) error {
provider, err := m.provider()
if err != nil {
return err
}
if err := m.migrateCustomProxyHttps(); err != nil {
defer provider.Close()
_, err = provider.UpTo(context.Background(), version)
return err
}

func createConfigTable(_ context.Context, tx *sql.Tx) error {
_, err := tx.Exec("create table if not exists config (key varchar primary key, value varchar)")
return err
}

func createOidcClientTable(_ context.Context, tx *sql.Tx) error {
_, err := tx.Exec(`create table if not exists oidc_client
(id varchar primary key, secret varchar, redirect_uri varchar, require_pkce integer, token_endpoint_auth_method varchar)`)
return err
}

func createCustomProxyTable(_ context.Context, tx *sql.Tx) error {
_, err := tx.Exec("create table if not exists custom_proxy (name varchar primary key, host varchar, port integer)")
return err
}

func addCustomProxyHttps(ctx context.Context, tx *sql.Tx) error {
return addColumnIfMissing(ctx, tx, "custom_proxy", "https", "integer not null default 0")
}

func addCustomProxyAuthelia(ctx context.Context, tx *sql.Tx) error {
return addColumnIfMissing(ctx, tx, "custom_proxy", "authelia", "integer not null default 0")
}

func normalizeOidcRedirectUris(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(`create table if not exists oidc_redirect_uri
(client_id varchar not null, redirect_uri varchar not null)`)
if err != nil {
return err
}
if err := m.migrateCustomProxyAuthelia(); err != nil {

hasLegacyColumn, err := columnExists(ctx, tx, "oidc_client", "redirect_uri")
if err != nil {
return err
}
return nil
}
if !hasLegacyColumn {
return nil
}

func (m *Migrator) addOidcClientTable() error {
_, err := m.db.Exec(`create table if not exists oidc_client
(id varchar primary key, secret varchar, redirect_uri varchar, require_pkce integer, token_endpoint_auth_method varchar)`)
_, err = tx.Exec(`insert into oidc_redirect_uri (client_id, redirect_uri)
select id, redirect_uri from oidc_client
where redirect_uri != '' and id not in (select client_id from oidc_redirect_uri)`)
if err != nil {
return fmt.Errorf("unable to add oidc_clients: %s", err)
return err
}
return nil

_, err = tx.Exec("ALTER TABLE oidc_client DROP COLUMN redirect_uri")
return err
}

func (m *Migrator) addCustomProxyTable() error {
_, err := m.db.Exec(`create table if not exists custom_proxy
(name varchar primary key, host varchar, port integer)`)
func addColumnIfMissing(ctx context.Context, tx *sql.Tx, table, column, definition string) error {
exists, err := columnExists(ctx, tx, table, column)
if err != nil {
return fmt.Errorf("unable to add custom_proxy table: %s", err)
return err
}
return nil
}

func (m *Migrator) migrateCustomProxyHttps() error {
_, _ = m.db.Exec("ALTER TABLE custom_proxy ADD COLUMN https integer not null default 0")
return nil
if exists {
return nil
}
_, err = tx.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, definition))
return err
}

func (m *Migrator) migrateCustomProxyAuthelia() error {
_, _ = m.db.Exec("ALTER TABLE custom_proxy ADD COLUMN authelia integer not null default 0")
return nil
func columnExists(ctx context.Context, tx *sql.Tx, table, column string) (bool, error) {
rows, err := tx.QueryContext(ctx, fmt.Sprintf("PRAGMA table_info(%s)", table))
if err != nil {
return false, err
}
defer rows.Close()
for rows.Next() {
var cid, notnull, pk int
var name, columnType string
var defaultValue sql.NullString
if err := rows.Scan(&cid, &name, &columnType, &notnull, &defaultValue, &pk); err != nil {
return false, err
}
if name == column {
return true, nil
}
}
return false, rows.Err()
}
61 changes: 60 additions & 1 deletion backend/config/migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,67 @@ func TestMigrator_CreatesSchemaFromScratch(t *testing.T) {

_, err := db.Exec("INSERT INTO custom_proxy(name, host, port, https, authelia) VALUES ('p', 'h', 1, 0, 1)")
assert.NoError(t, err)
_, err = db.Exec("INSERT INTO oidc_client VALUES ('id', 's', '/cb', 0, 'm')")
_, err = db.Exec("INSERT INTO oidc_client(id, secret, require_pkce, token_endpoint_auth_method) VALUES ('id', 's', 0, 'm')")
assert.NoError(t, err)
_, err = db.Exec("INSERT INTO oidc_redirect_uri(client_id, redirect_uri) VALUES ('id', '/cb')")
assert.NoError(t, err)
}

func TestMigrator_UpgradesPreGooseDbWithoutVersionTable(t *testing.T) {
dbFile := path.Join(t.TempDir(), "db")

pre, err := sql.Open("sqlite", fmt.Sprintf("file:%s?_pragma=busy_timeout(5000)&_pragma=journal_mode(WAL)", dbFile))
assert.NoError(t, err)
_, err = pre.Exec("create table config (key varchar primary key, value varchar)")
assert.NoError(t, err)
_, err = pre.Exec(`create table oidc_client
(id varchar primary key, secret varchar, redirect_uri varchar, require_pkce integer, token_endpoint_auth_method varchar)`)
assert.NoError(t, err)
_, err = pre.Exec("INSERT INTO oidc_client VALUES ('legacy', 's', '/old/callback', 1, 'client_secret_basic')")
assert.NoError(t, err)
_, err = pre.Exec("create table custom_proxy (name varchar primary key, host varchar, port integer, https integer not null default 0, authelia integer not null default 0)")
assert.NoError(t, err)
assert.NoError(t, pre.Close())

assert.NoError(t, NewMigrator(NewDb(dbFile, log.Default())).Migrate())

db := NewDb(dbFile, log.Default())
clients, err := NewOIDC(db).Clients()
assert.NoError(t, err)
assert.Len(t, clients, 1)
assert.Equal(t, []string{"/old/callback"}, clients[0].RedirectURIs)

conn := db.Open()
defer conn.Close()
_, err = conn.Query("select redirect_uri from oidc_client")
assert.Error(t, err, "legacy redirect_uri column must be dropped on a pre-goose db")
}

func TestMigrator_MigratesLegacyRedirectUriIntoTableAndDropsColumn(t *testing.T) {
db := NewDb(path.Join(t.TempDir(), "db"), log.Default())
m := NewMigrator(db)

assert.NoError(t, m.MigrateTo(5))
_, err := db.Exec("INSERT INTO oidc_client(id, secret, redirect_uri, require_pkce, token_endpoint_auth_method) VALUES ('app', 's', '/old/callback', 1, 'client_secret_basic')")
assert.NoError(t, err)

assert.NoError(t, m.Migrate())

conn := db.Open()
defer conn.Close()

rows, err := conn.Query("select client_id, redirect_uri from oidc_redirect_uri")
assert.NoError(t, err)
defer rows.Close()
assert.True(t, rows.Next())
var clientID, redirectURI string
assert.NoError(t, rows.Scan(&clientID, &redirectURI))
assert.Equal(t, "app", clientID)
assert.Equal(t, "/old/callback", redirectURI)
assert.False(t, rows.Next())

_, err = conn.Query("select redirect_uri from oidc_client")
assert.Error(t, err, "legacy redirect_uri column must be dropped from oidc_client")
}

func TestMigrator_IsIdempotent(t *testing.T) {
Expand Down
58 changes: 51 additions & 7 deletions backend/config/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package config
type OIDCClient struct {
ID string
Secret string
RedirectURI string
RedirectURIs []string
RequirePkce bool
TokenEndpointAuthMethod string
}
Expand All @@ -19,7 +19,26 @@ func NewOIDC(db *Db) *OIDC {
func (o *OIDC) Clients() ([]OIDCClient, error) {
db := o.db.Open()
defer db.Close()
rows, err := db.Query("select id, secret, redirect_uri, require_pkce, token_endpoint_auth_method from oidc_client")

uriRows, err := db.Query("select client_id, redirect_uri from oidc_redirect_uri order by rowid")
if err != nil {
return nil, err
}
urisByClient := map[string][]string{}
for uriRows.Next() {
var clientID, redirectURI string
if err := uriRows.Scan(&clientID, &redirectURI); err != nil {
uriRows.Close()
return nil, err
}
urisByClient[clientID] = append(urisByClient[clientID], redirectURI)
}
uriRows.Close()
if err := uriRows.Err(); err != nil {
return nil, err
}

rows, err := db.Query("select id, secret, require_pkce, token_endpoint_auth_method from oidc_client")
if err != nil {
return nil, err
}
Expand All @@ -32,13 +51,13 @@ func (o *OIDC) Clients() ([]OIDCClient, error) {
if err := rows.Scan(
&client.ID,
&client.Secret,
&client.RedirectURI,
&requirePkce,
&client.TokenEndpointAuthMethod,
); err != nil {
return clients, err
}
client.RequirePkce = requirePkce != 0
client.RedirectURIs = urisByClient[client.ID]
clients = append(clients, client)
}
return clients, rows.Err()
Expand All @@ -49,8 +68,33 @@ func (o *OIDC) AddClient(client OIDCClient) error {
if client.RequirePkce {
requirePkce = 1
}
_, err := o.db.Exec("INSERT OR REPLACE INTO oidc_client VALUES (?, ?, ?, ?, ?)",
client.ID, client.Secret, client.RedirectURI, requirePkce, client.TokenEndpointAuthMethod,
)
return err

db := o.db.Open()
defer db.Close()
tx, err := db.Begin()
if err != nil {
return err
}

if _, err := tx.Exec(
"INSERT OR REPLACE INTO oidc_client(id, secret, require_pkce, token_endpoint_auth_method) VALUES (?, ?, ?, ?)",
client.ID, client.Secret, requirePkce, client.TokenEndpointAuthMethod,
); err != nil {
tx.Rollback()
return err
}
if _, err := tx.Exec("DELETE FROM oidc_redirect_uri WHERE client_id = ?", client.ID); err != nil {
tx.Rollback()
return err
}
for _, redirectURI := range client.RedirectURIs {
if _, err := tx.Exec(
"INSERT INTO oidc_redirect_uri(client_id, redirect_uri) VALUES (?, ?)",
client.ID, redirectURI,
); err != nil {
tx.Rollback()
return err
}
}
return tx.Commit()
}
Loading