diff options
Diffstat (limited to 'userapi/storage/accounts/postgres/storage.go')
-rw-r--r-- | userapi/storage/accounts/postgres/storage.go | 49 |
1 files changed, 38 insertions, 11 deletions
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 3933fe5b..c5e74ed1 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "strconv" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -39,25 +40,28 @@ type Database struct { db *sql.DB writer sqlutil.Writer sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - serverName gomatrixserverlib.ServerName - bcryptCost int + accounts accountsStatements + profiles profilesStatements + accountDatas accountDataStatements + threepids threepidStatements + openIDTokens tokenStatements + serverName gomatrixserverlib.ServerName + bcryptCost int + openIDTokenLifetimeMS int64 } // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewDummyWriter(), - bcryptCost: bcryptCost, + serverName: serverName, + db: db, + writer: sqlutil.NewDummyWriter(), + bcryptCost: bcryptCost, + openIDTokenLifetimeMS: openIDTokenLifetimeMS, } // Create tables before executing migrations so we don't fail if the table is missing, @@ -86,6 +90,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.threepids.prepare(db); err != nil { return nil, err } + if err = d.openIDTokens.prepare(db, serverName); err != nil { + return nil, err + } return d, nil } @@ -341,3 +348,23 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { return d.accounts.deactivateAccount(ctx, localpart) } + +// CreateOpenIDToken persists a new token that was issued through OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + token, localpart string, +) (int64, error) { + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS + err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) + }) + return expiresAtMS, err +} + +// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token +func (d *Database) GetOpenIDTokenAttributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) +} |