diff options
author | Bruce MacDonald <brucewmacdonald@gmail.com> | 2021-04-07 05:26:20 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-07 13:26:20 +0100 |
commit | d27607af78a53bda636f14f603b02b2952d6e1d8 (patch) | |
tree | c5c5488c7395a45af24ef598308ef7f6545515ca /userapi/storage/accounts/sqlite3 | |
parent | f8d3a762c49a1dafe4e484a2440ade2bb6ba32ac (diff) |
Implement OpenID module (#599) (#1812)
* Implement OpenID module (#599)
- Unrelated: change Riot references to Element in client API routing
Signed-off-by: Bruce MacDonald <contact@bruce-macdonald.com>
* OpenID module tweaks (#599)
- specify expiry is ms rather than vague ts
- add OpenID token lifetime to configuration
- use Go naming conventions for the path params
- store plaintext token rather than hash
- remove openid table sqllite mutex
* Add default OpenID token lifetime (#599)
* Update dendrite-config.yaml
Co-authored-by: Kegsay <kegsay@gmail.com>
Co-authored-by: Kegsay <kegan@matrix.org>
Diffstat (limited to 'userapi/storage/accounts/sqlite3')
-rw-r--r-- | userapi/storage/accounts/sqlite3/openid_table.go | 86 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/storage.go | 49 |
2 files changed, 124 insertions, 11 deletions
diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/accounts/sqlite3/openid_table.go new file mode 100644 index 00000000..80b9dd4c --- /dev/null +++ b/userapi/storage/accounts/sqlite3/openid_table.go @@ -0,0 +1,86 @@ +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +const openIDTokenSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS open_id_tokens ( + -- The value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- The Matrix user ID for this account + localpart TEXT NOT NULL, + -- When the token expires, as a unix timestamp (ms resolution). + token_expires_at_ms BIGINT NOT NULL +); +` + +const insertTokenSQL = "" + + "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + +const selectTokenSQL = "" + + "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" + +type tokenStatements struct { + db *sql.DB + insertTokenStmt *sql.Stmt + selectTokenStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + _, err = db.Exec(openIDTokenSchema) + if err != nil { + return err + } + if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { + return + } + if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertToken inserts a new OpenID Connect token to the DB. +// Returns new token, otherwise returns error if the token already exists. +func (s *tokenStatements) insertToken( + ctx context.Context, + txn *sql.Tx, + token, localpart string, + expiresAtMS int64, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) + _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + return +} + +// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB +// Returns the existing token's attributes, or err if no token is found +func (s *tokenStatements) selectOpenIDTokenAtrributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + var openIDTokenAttrs api.OpenIDTokenAttributes + err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( + &openIDTokenAttrs.UserID, + &openIDTokenAttrs.ExpiresAtMS, + ) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve token from the db") + } + return nil, err + } + + return &openIDTokenAttrs, nil +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 07cc68b3..c0f7118c 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -21,6 +21,7 @@ import ( "errors" "strconv" "sync" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -37,12 +38,14 @@ type Database struct { writer sqlutil.Writer sqlutil.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - accountDatas accountDataStatements - threepids threepidStatements - serverName gomatrixserverlib.ServerName - bcryptCost int + accounts accountsStatements + profiles profilesStatements + accountDatas accountDataStatements + threepids threepidStatements + openIDTokens tokenStatements + serverName gomatrixserverlib.ServerName + bcryptCost int + openIDTokenLifetimeMS int64 accountsMu sync.Mutex profilesMu sync.Mutex @@ -51,16 +54,17 @@ type Database struct { } // NewDatabase creates a new accounts and profiles database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } d := &Database{ - serverName: serverName, - db: db, - writer: sqlutil.NewExclusiveWriter(), - bcryptCost: bcryptCost, + serverName: serverName, + db: db, + writer: sqlutil.NewExclusiveWriter(), + bcryptCost: bcryptCost, + openIDTokenLifetimeMS: openIDTokenLifetimeMS, } // Create tables before executing migrations so we don't fail if the table is missing, @@ -90,6 +94,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.threepids.prepare(db); err != nil { return nil, err } + if err = d.openIDTokens.prepare(db, serverName); err != nil { + return nil, err + } return d, nil } @@ -379,3 +386,23 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { return d.accounts.deactivateAccount(ctx, localpart) } + +// CreateOpenIDToken persists a new token that was issued for OpenID Connect +func (d *Database) CreateOpenIDToken( + ctx context.Context, + token, localpart string, +) (int64, error) { + expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS + err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS) + }) + return expiresAtMS, err +} + +// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token +func (d *Database) GetOpenIDTokenAttributes( + ctx context.Context, + token string, +) (*api.OpenIDTokenAttributes, error) { + return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token) +} |