diff options
author | Kegsay <kegan@matrix.org> | 2020-02-13 17:27:33 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-02-13 17:27:33 +0000 |
commit | b6ea1bc67ab51667b9e139dd05e0778aca025501 (patch) | |
tree | 18569c317fd28544144c320ce844d93a8ff8ec5e /clientapi | |
parent | 6942ee1de0250235164cf0ce45570b7fc919669d (diff) |
Support sqlite in addition to postgres (#869)
* Move current work into single branch
* Initial massaging of clientapi etc (not working yet)
* Interfaces for accounts/devices databases
* Duplicate postgres package for sqlite3 (no changes made to it yet)
* Some keydb, accountdb, devicedb, common partition fixes, some more syncapi tweaking
* Fix accounts DB, device DB
* Update naffka dependency for SQLite
* Naffka SQLite
* Update naffka to latest master
* SQLite support for federationsender
* Mostly not-bad support for SQLite in syncapi (although there are problems where lots of events get classed incorrectly as backward extremities, probably because of IN/ANY clauses that are badly supported)
* Update Dockerfile -> Go 1.13.7, add build-base (as gcc and friends are needed for SQLite)
* Implement GET endpoints for account_data in clientapi
* Nuke filtering for now...
* Revert "Implement GET endpoints for account_data in clientapi"
This reverts commit 4d80dff4583d278620d9b3ed437e9fcd8d4674ee.
* Implement GET endpoints for account_data in clientapi (#861)
* Implement GET endpoints for account_data in clientapi
* Fix accountDB parameter
* Remove fmt.Println
* Fix insertAccountData SQLite query
* Fix accountDB storage interfaces
* Add empty push rules into account data on account creation (#862)
* Put SaveAccountData into the right function this time
* Not sure if roomserver is better or worse now
* sqlite work
* Allow empty last sent ID for the first event
* sqlite: room creation works
* Support sending messages
* Nuke fmt.println
* Move QueryVariadic etc into common, other device fixes
* Fix some linter issues
* Fix bugs
* Fix some linting errors
* Fix errcheck lint errors
* Make naffka use postgres as fallback, fix couple of compile errors
* What on earth happened to the /rooms/{roomID}/send/{eventType} routing
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'clientapi')
37 files changed, 2318 insertions, 608 deletions
diff --git a/clientapi/auth/storage/accounts/account_data_table.go b/clientapi/auth/storage/accounts/postgres/account_data_table.go index 1b7484d8..d0cfcc0c 100644 --- a/clientapi/auth/storage/accounts/account_data_table.go +++ b/clientapi/auth/storage/accounts/postgres/account_data_table.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package postgres import ( "context" diff --git a/clientapi/auth/storage/accounts/accounts_table.go b/clientapi/auth/storage/accounts/postgres/accounts_table.go index e86654ec..6b8ed372 100644 --- a/clientapi/auth/storage/accounts/accounts_table.go +++ b/clientapi/auth/storage/accounts/postgres/accounts_table.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package postgres import ( "context" diff --git a/clientapi/auth/storage/accounts/filter_table.go b/clientapi/auth/storage/accounts/postgres/filter_table.go index 2b07ef17..c54e4bc4 100644 --- a/clientapi/auth/storage/accounts/filter_table.go +++ b/clientapi/auth/storage/accounts/postgres/filter_table.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package postgres import ( "context" diff --git a/clientapi/auth/storage/accounts/membership_table.go b/clientapi/auth/storage/accounts/postgres/membership_table.go index 7b7c50ac..426c2d6a 100644 --- a/clientapi/auth/storage/accounts/membership_table.go +++ b/clientapi/auth/storage/accounts/postgres/membership_table.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package postgres import ( "context" diff --git a/clientapi/auth/storage/accounts/profile_table.go b/clientapi/auth/storage/accounts/postgres/profile_table.go index 157bb99b..38c76c40 100644 --- a/clientapi/auth/storage/accounts/profile_table.go +++ b/clientapi/auth/storage/accounts/postgres/profile_table.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package postgres import ( "context" diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go new file mode 100644 index 00000000..cb74d131 --- /dev/null +++ b/clientapi/auth/storage/accounts/postgres/storage.go @@ -0,0 +1,392 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + + // Import the postgres database driver. + _ "github.com/lib/pq" +) + +// Database represents an account database +type Database struct { + db *sql.DB + common.PartitionOffsetStatements + accounts accountsStatements + profiles profilesStatements + memberships membershipStatements + accountDatas accountDataStatements + threepids threepidStatements + filter filterStatements + serverName gomatrixserverlib.ServerName +} + +// NewDatabase creates a new accounts and profiles database +func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { + var db *sql.DB + var err error + if db, err = sql.Open("postgres", dataSourceName); err != nil { + return nil, err + } + partitions := common.PartitionOffsetStatements{} + if err = partitions.Prepare(db, "account"); err != nil { + return nil, err + } + a := accountsStatements{} + if err = a.prepare(db, serverName); err != nil { + return nil, err + } + p := profilesStatements{} + if err = p.prepare(db); err != nil { + return nil, err + } + m := membershipStatements{} + if err = m.prepare(db); err != nil { + return nil, err + } + ac := accountDataStatements{} + if err = ac.prepare(db); err != nil { + return nil, err + } + t := threepidStatements{} + if err = t.prepare(db); err != nil { + return nil, err + } + f := filterStatements{} + if err = f.prepare(db); err != nil { + return nil, err + } + return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil +} + +// GetAccountByPassword returns the account associated with the given localpart and password. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByPassword( + ctx context.Context, localpart, plaintextPassword string, +) (*authtypes.Account, error) { + hash, err := d.accounts.selectPasswordHash(ctx, localpart) + if err != nil { + return nil, err + } + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { + return nil, err + } + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} + +// GetProfileByLocalpart returns the profile associated with the given localpart. +// Returns sql.ErrNoRows if no profile exists which matches the given localpart. +func (d *Database) GetProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + return d.profiles.selectProfileByLocalpart(ctx, localpart) +} + +// SetAvatarURL updates the avatar URL of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) error { + return d.profiles.setAvatarURL(ctx, localpart, avatarURL) +} + +// SetDisplayName updates the display name of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetDisplayName( + ctx context.Context, localpart string, displayName string, +) error { + return d.profiles.setDisplayName(ctx, localpart, displayName) +} + +// CreateAccount makes a new account with the given login name and password, and creates an empty profile +// for this account. If no password is supplied, the account will be a passwordless account. If the +// account already exists, it will return nil, nil. +func (d *Database) CreateAccount( + ctx context.Context, localpart, plaintextPassword, appserviceID string, +) (*authtypes.Account, error) { + var err error + + // Generate a password hash if this is not a password-less user + hash := "" + if plaintextPassword != "" { + hash, err = hashPassword(plaintextPassword) + if err != nil { + return nil, err + } + } + if err := d.profiles.insertProfile(ctx, localpart); err != nil { + if common.IsUniqueConstraintViolationErr(err) { + return nil, nil + } + return nil, err + } + if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{ + "global": { + "content": [], + "override": [], + "room": [], + "sender": [], + "underride": [] + } + }`); err != nil { + return nil, err + } + return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) +} + +// SaveMembership saves the user matching a given localpart as a member of a given +// room. It also stores the ID of the membership event. +// If a membership already exists between the user and the room, or if the +// insert fails, returns the SQL error +func (d *Database) saveMembership( + ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, +) error { + return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID) +} + +// removeMembershipsByEventIDs removes the memberships corresponding to the +// `join` membership events IDs in the eventIDs slice. +// If the removal fails, or if there is no membership to remove, returns an error +func (d *Database) removeMembershipsByEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) error { + return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs) +} + +// UpdateMemberships adds the "join" membership events included in a given state +// events array, and removes those which ID is included in a given array of events +// IDs. All of the process is run in a transaction, which commits only once/if every +// insertion and deletion has been successfully processed. +// Returns a SQL error if there was an issue with any part of the process +func (d *Database) UpdateMemberships( + ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { + return err + } + + for _, event := range eventsToAdd { + if err := d.newMembership(ctx, txn, event); err != nil { + return err + } + } + + return nil + }) +} + +// GetMembershipInRoomByLocalpart returns the membership for an user +// matching the given localpart if he is a member of the room matching roomID, +// if not sql.ErrNoRows is returned. +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembershipInRoomByLocalpart( + ctx context.Context, localpart, roomID string, +) (authtypes.Membership, error) { + return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID) +} + +// GetMembershipsByLocalpart returns an array containing the memberships for all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembershipsByLocalpart( + ctx context.Context, localpart string, +) (memberships []authtypes.Membership, err error) { + return d.memberships.selectMembershipsByLocalpart(ctx, localpart) +} + +// newMembership saves a new membership in the database. +// If the event isn't a valid m.room.member event with type `join`, does nothing. +// If an error occurred, returns the SQL error +func (d *Database) newMembership( + ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event, +) error { + if ev.Type() == "m.room.member" && ev.StateKey() != nil { + localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + if err != nil { + return err + } + + // We only want state events from local users + if string(serverName) != string(d.serverName) { + return nil + } + + eventID := ev.EventID() + roomID := ev.RoomID() + membership, err := ev.Membership() + if err != nil { + return err + } + + // Only "join" membership events can be considered as new memberships + if membership == gomatrixserverlib.Join { + if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil { + return err + } + } + } + return nil +} + +// SaveAccountData saves new account data for a given user and a given room. +// If the account data is not specific to a room, the room ID should be an empty string +// If an account data already exists for a given set (user, room, data type), it will +// update the corresponding row with the new content +// Returns a SQL error if there was an issue with the insertion/update +func (d *Database) SaveAccountData( + ctx context.Context, localpart, roomID, dataType, content string, +) error { + return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) +} + +// GetAccountData returns account data related to a given localpart +// If no account data could be found, returns an empty arrays +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountData(ctx context.Context, localpart string) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + return d.accountDatas.selectAccountData(ctx, localpart) +} + +// GetAccountDataByType returns account data matching a given +// localpart, room ID and type. +// If no account data could be found, returns nil +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data *gomatrixserverlib.ClientEvent, err error) { + return d.accountDatas.selectAccountDataByType( + ctx, localpart, roomID, dataType, + ) +} + +// GetNewNumericLocalpart generates and returns a new unused numeric localpart +func (d *Database) GetNewNumericLocalpart( + ctx context.Context, +) (int64, error) { + return d.accounts.selectNewNumericLocalpart(ctx) +} + +func hashPassword(plaintext string) (hash string, err error) { + hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost) + return string(hashBytes), err +} + +// Err3PIDInUse is the error returned when trying to save an association involving +// a third-party identifier which is already associated to a local user. +var Err3PIDInUse = errors.New("This third-party identifier is already in use") + +// SaveThreePIDAssociation saves the association between a third party identifier +// and a local Matrix user (identified by the user's ID's local part). +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) SaveThreePIDAssociation( + ctx context.Context, threepid, localpart, medium string, +) (err error) { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + user, err := d.threepids.selectLocalpartForThreePID( + ctx, txn, threepid, medium, + ) + if err != nil { + return err + } + + if len(user) > 0 { + return Err3PIDInUse + } + + return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) + }) +} + +// RemoveThreePIDAssociation removes the association involving a given third-party +// identifier. +// If no association exists involving this third-party identifier, returns nothing. +// If there was a problem talking to the database, returns an error. +func (d *Database) RemoveThreePIDAssociation( + ctx context.Context, threepid string, medium string, +) (err error) { + return d.threepids.deleteThreePID(ctx, threepid, medium) +} + +// GetLocalpartForThreePID looks up the localpart associated with a given third-party +// identifier. +// If no association involves the given third-party idenfitier, returns an empty +// string. +// Returns an error if there was a problem talking to the database. +func (d *Database) GetLocalpartForThreePID( + ctx context.Context, threepid string, medium string, +) (localpart string, err error) { + return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) +} + +// GetThreePIDsForLocalpart looks up the third-party identifiers associated with +// a given local user. +// If no association is known for this user, returns an empty slice. +// Returns an error if there was an issue talking to the database. +func (d *Database) GetThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) +} + +// GetFilter looks up the filter associated with a given local user and filter ID. +// Returns a filter structure. Otherwise returns an error if no such filter exists +// or if there was an error talking to the database. +func (d *Database) GetFilter( + ctx context.Context, localpart string, filterID string, +) (*gomatrixserverlib.Filter, error) { + return d.filter.selectFilter(ctx, localpart, filterID) +} + +// PutFilter puts the passed filter into the database. +// Returns the filterID as a string. Otherwise returns an error if something +// goes wrong. +func (d *Database) PutFilter( + ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, +) (string, error) { + return d.filter.insertFilter(ctx, filter, localpart) +} + +// CheckAccountAvailability checks if the username/localpart is already present +// in the database. +// If the DB returns sql.ErrNoRows the Localpart isn't taken. +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { + _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) + if err == sql.ErrNoRows { + return true, nil + } + return false, err +} + +// GetAccountByLocalpart returns the account associated with the given localpart. +// This function assumes the request is authenticated or the account data is used only internally. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +) (*authtypes.Account, error) { + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} diff --git a/clientapi/auth/storage/accounts/postgres/threepid_table.go b/clientapi/auth/storage/accounts/postgres/threepid_table.go new file mode 100644 index 00000000..851b4a90 --- /dev/null +++ b/clientapi/auth/storage/accounts/postgres/threepid_table.go @@ -0,0 +1,129 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" +) + +const threepidSchema = ` +-- Stores data about third party identifiers +CREATE TABLE IF NOT EXISTS account_threepid ( + -- The third party identifier + threepid TEXT NOT NULL, + -- The 3PID medium + medium TEXT NOT NULL DEFAULT 'email', + -- The localpart of the Matrix user ID associated to this 3PID + localpart TEXT NOT NULL, + + PRIMARY KEY(threepid, medium) +); + +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +` + +const selectLocalpartForThreePIDSQL = "" + + "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + +const selectThreePIDsForLocalpartSQL = "" + + "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + +const insertThreePIDSQL = "" + + "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + +const deleteThreePIDSQL = "" + + "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + +type threepidStatements struct { + selectLocalpartForThreePIDStmt *sql.Stmt + selectThreePIDsForLocalpartStmt *sql.Stmt + insertThreePIDStmt *sql.Stmt + deleteThreePIDStmt *sql.Stmt +} + +func (s *threepidStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(threepidSchema) + if err != nil { + return + } + if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { + return + } + if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { + return + } + if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { + return + } + if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { + return + } + + return +} + +func (s *threepidStatements) selectLocalpartForThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string, +) (localpart string, err error) { + stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *threepidStatements) selectThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + if err != nil { + return + } + + threepids = []authtypes.ThreePID{} + for rows.Next() { + var threepid string + var medium string + if err = rows.Scan(&threepid, &medium); err != nil { + return + } + threepids = append(threepids, authtypes.ThreePID{ + Address: threepid, + Medium: medium, + }) + } + + return +} + +func (s *threepidStatements) insertThreePID( + ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, +) (err error) { + stmt := common.TxStmt(txn, s.insertThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + return +} + +func (s *threepidStatements) deleteThreePID( + ctx context.Context, threepid string, medium string) (err error) { + _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) + return +} diff --git a/clientapi/auth/storage/accounts/sqlite3/account_data_table.go b/clientapi/auth/storage/accounts/sqlite3/account_data_table.go new file mode 100644 index 00000000..c2143881 --- /dev/null +++ b/clientapi/auth/storage/accounts/sqlite3/account_data_table.go @@ -0,0 +1,141 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +const accountDataSchema = ` +-- Stores data about accounts data. +CREATE TABLE IF NOT EXISTS account_data ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL, + -- The room ID for this data (empty string if not specific to a room) + room_id TEXT, + -- The account data type + type TEXT NOT NULL, + -- The account data content + content TEXT NOT NULL, + + PRIMARY KEY(localpart, room_id, type) +); +` + +const insertAccountDataSQL = ` + INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 +` + +const selectAccountDataSQL = "" + + "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + +const selectAccountDataByTypeSQL = "" + + "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + +type accountDataStatements struct { + insertAccountDataStmt *sql.Stmt + selectAccountDataStmt *sql.Stmt + selectAccountDataByTypeStmt *sql.Stmt +} + +func (s *accountDataStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(accountDataSchema) + if err != nil { + return + } + if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { + return + } + if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { + return + } + if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { + return + } + return +} + +func (s *accountDataStatements) insertAccountData( + ctx context.Context, localpart, roomID, dataType, content string, +) (err error) { + stmt := s.insertAccountDataStmt + _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) + return +} + +func (s *accountDataStatements) selectAccountData( + ctx context.Context, localpart string, +) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + if err != nil { + return + } + + global = []gomatrixserverlib.ClientEvent{} + rooms = make(map[string][]gomatrixserverlib.ClientEvent) + + for rows.Next() { + var roomID string + var dataType string + var content []byte + + if err = rows.Scan(&roomID, &dataType, &content); err != nil { + return + } + + ac := gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: content, + } + + if len(roomID) > 0 { + rooms[roomID] = append(rooms[roomID], ac) + } else { + global = append(global, ac) + } + } + + return +} + +func (s *accountDataStatements) selectAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data *gomatrixserverlib.ClientEvent, err error) { + stmt := s.selectAccountDataByTypeStmt + var content []byte + + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + + return + } + + data = &gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: content, + } + + return +} diff --git a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go new file mode 100644 index 00000000..b029951f --- /dev/null +++ b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go @@ -0,0 +1,151 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/gomatrixserverlib" + + log "github.com/sirupsen/logrus" +) + +const accountsSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS account_accounts ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- When this account was first created, as a unix timestamp (ms resolution). + created_ts BIGINT NOT NULL, + -- The password hash for this account. Can be NULL if this is a passwordless account. + password_hash TEXT, + -- Identifies which application service this account belongs to, if any. + appservice_id TEXT + -- TODO: + -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? +); +` + +const insertAccountSQL = "" + + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + +const selectAccountByLocalpartSQL = "" + + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + +const selectPasswordHashSQL = "" + + "SELECT password_hash FROM account_accounts WHERE localpart = $1" + +const selectNewNumericLocalpartSQL = "" + + "SELECT COUNT(localpart) FROM account_accounts" + +// TODO: Update password + +type accountsStatements struct { + insertAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + _, err = db.Exec(accountsSchema) + if err != nil { + return + } + if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { + return + } + if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { + return + } + if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { + return + } + if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, +// this account will be passwordless. Returns an error if this account already exists. Returns the account +// on success. +func (s *accountsStatements) insertAccount( + ctx context.Context, localpart, hash, appserviceID string, +) (*authtypes.Account, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + stmt := s.insertAccountStmt + + var err error + if appserviceID == "" { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) + } else { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + } + if err != nil { + return nil, err + } + + return &authtypes.Account{ + Localpart: localpart, + UserID: userutil.MakeUserID(localpart, s.serverName), + ServerName: s.serverName, + AppServiceID: appserviceID, + }, nil +} + +func (s *accountsStatements) selectPasswordHash( + ctx context.Context, localpart string, +) (hash string, err error) { + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + return +} + +func (s *accountsStatements) selectAccountByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Account, error) { + var appserviceIDPtr sql.NullString + var acc authtypes.Account + + stmt := s.selectAccountByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve user from the db") + } + return nil, err + } + if appserviceIDPtr.Valid { + acc.AppServiceID = appserviceIDPtr.String + } + + acc.UserID = userutil.MakeUserID(localpart, s.serverName) + acc.ServerName = s.serverName + + return &acc, nil +} + +func (s *accountsStatements) selectNewNumericLocalpart( + ctx context.Context, +) (id int64, err error) { + err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) + return +} diff --git a/clientapi/auth/storage/accounts/sqlite3/filter_table.go b/clientapi/auth/storage/accounts/sqlite3/filter_table.go new file mode 100644 index 00000000..691ead77 --- /dev/null +++ b/clientapi/auth/storage/accounts/sqlite3/filter_table.go @@ -0,0 +1,139 @@ +// Copyright 2017 Jan Christian Grünhage +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" +) + +const filterSchema = ` +-- Stores data about filters +CREATE TABLE IF NOT EXISTS account_filter ( + -- The filter + filter TEXT NOT NULL, + -- The ID + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The localpart of the Matrix user ID associated to this filter + localpart TEXT NOT NULL, + + UNIQUE (id, localpart) +); + +CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart); +` + +const selectFilterSQL = "" + + "SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2" + +const selectFilterIDByContentSQL = "" + + "SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2" + +const insertFilterSQL = "" + + "INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)" + +const selectLastInsertedFilterIDSQL = "" + + "SELECT id FROM account_filter WHERE rowid = last_insert_rowid()" + +type filterStatements struct { + selectFilterStmt *sql.Stmt + selectLastInsertedFilterIDStmt *sql.Stmt + selectFilterIDByContentStmt *sql.Stmt + insertFilterStmt *sql.Stmt +} + +func (s *filterStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(filterSchema) + if err != nil { + return + } + if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { + return + } + if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil { + return + } + if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { + return + } + if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { + return + } + return +} + +func (s *filterStatements) selectFilter( + ctx context.Context, localpart string, filterID string, +) (*gomatrixserverlib.Filter, error) { + // Retrieve filter from database (stored as canonical JSON) + var filterData []byte + err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + if err != nil { + return nil, err + } + + // Unmarshal JSON into Filter struct + var filter gomatrixserverlib.Filter + if err = json.Unmarshal(filterData, &filter); err != nil { + return nil, err + } + return &filter, nil +} + +func (s *filterStatements) insertFilter( + ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, +) (filterID string, err error) { + var existingFilterID string + + // Serialise json + filterJSON, err := json.Marshal(filter) + if err != nil { + return "", err + } + // Remove whitespaces and sort JSON data + // needed to prevent from inserting the same filter multiple times + filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON) + if err != nil { + return "", err + } + + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil && err != sql.ErrNoRows { + return "", err + } + // If it does, return the existing ID + if existingFilterID != "" { + return existingFilterID, err + } + + // Otherwise insert the filter and return the new ID + if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil { + return "", err + } + row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx) + if err := row.Scan(&filterID); err != nil { + return "", err + } + return +} diff --git a/clientapi/auth/storage/accounts/sqlite3/membership_table.go b/clientapi/auth/storage/accounts/sqlite3/membership_table.go new file mode 100644 index 00000000..8e5e69ba --- /dev/null +++ b/clientapi/auth/storage/accounts/sqlite3/membership_table.go @@ -0,0 +1,131 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" +) + +const membershipSchema = ` +-- Stores data about users memberships to rooms. +CREATE TABLE IF NOT EXISTS account_memberships ( + -- The Matrix user ID localpart for the member + localpart TEXT NOT NULL, + -- The room this user is a member of + room_id TEXT NOT NULL, + -- The ID of the join membership event + event_id TEXT NOT NULL, + + -- A user can only be member of a room once + PRIMARY KEY (localpart, room_id), + + UNIQUE (event_id) +); +` + +const insertMembershipSQL = ` + INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3) + ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id +` + +const selectMembershipsByLocalpartSQL = "" + + "SELECT room_id, event_id FROM account_memberships WHERE localpart = $1" + +const selectMembershipInRoomByLocalpartSQL = "" + + "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" + +const deleteMembershipsByEventIDsSQL = "" + + "DELETE FROM account_memberships WHERE event_id IN ($1)" + +type membershipStatements struct { + deleteMembershipsByEventIDsStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipInRoomByLocalpartStmt *sql.Stmt + selectMembershipsByLocalpartStmt *sql.Stmt +} + +func (s *membershipStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(membershipSchema) + if err != nil { + return + } + if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil { + return + } + if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil { + return + } + if s.selectMembershipInRoomByLocalpartStmt, err = db.Prepare(selectMembershipInRoomByLocalpartSQL); err != nil { + return + } + if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { + return + } + return +} + +func (s *membershipStatements) insertMembership( + ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, +) (err error) { + stmt := txn.Stmt(s.insertMembershipStmt) + _, err = stmt.ExecContext(ctx, localpart, roomID, eventID) + return +} + +func (s *membershipStatements) deleteMembershipsByEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (err error) { + stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt) + _, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs)) + return +} + +func (s *membershipStatements) selectMembershipInRoomByLocalpart( + ctx context.Context, localpart, roomID string, +) (authtypes.Membership, error) { + membership := authtypes.Membership{Localpart: localpart, RoomID: roomID} + stmt := s.selectMembershipInRoomByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart, roomID).Scan(&membership.EventID) + + return membership, err +} + +func (s *membershipStatements) selectMembershipsByLocalpart( + ctx context.Context, localpart string, +) (memberships []authtypes.Membership, err error) { + stmt := s.selectMembershipsByLocalpartStmt + rows, err := stmt.QueryContext(ctx, localpart) + if err != nil { + return + } + + memberships = []authtypes.Membership{} + + defer rows.Close() // nolint: errcheck + for rows.Next() { + var m authtypes.Membership + m.Localpart = localpart + if err := rows.Scan(&m.RoomID, &m.EventID); err != nil { + return nil, err + } + memberships = append(memberships, m) + } + + return +} diff --git a/clientapi/auth/storage/accounts/sqlite3/profile_table.go b/clientapi/auth/storage/accounts/sqlite3/profile_table.go new file mode 100644 index 00000000..7af8307e --- /dev/null +++ b/clientapi/auth/storage/accounts/sqlite3/profile_table.go @@ -0,0 +1,107 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" +) + +const profilesSchema = ` +-- Stores data about accounts profiles. +CREATE TABLE IF NOT EXISTS account_profiles ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- The display name for this account + display_name TEXT, + -- The URL of the avatar for this account + avatar_url TEXT +); +` + +const insertProfileSQL = "" + + "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + +const selectProfileByLocalpartSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + +const setAvatarURLSQL = "" + + "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + +const setDisplayNameSQL = "" + + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + +type profilesStatements struct { + insertProfileStmt *sql.Stmt + selectProfileByLocalpartStmt *sql.Stmt + setAvatarURLStmt *sql.Stmt + setDisplayNameStmt *sql.Stmt +} + +func (s *profilesStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(profilesSchema) + if err != nil { + return + } + if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { + return + } + if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { + return + } + if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { + return + } + if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { + return + } + return +} + +func (s *profilesStatements) insertProfile( + ctx context.Context, localpart string, +) (err error) { + _, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") + return +} + +func (s *profilesStatements) selectProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + var profile authtypes.Profile + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( + &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + ) + if err != nil { + return nil, err + } + return &profile, nil +} + +func (s *profilesStatements) setAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) (err error) { + _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) + return +} + +func (s *profilesStatements) setDisplayName( + ctx context.Context, localpart string, displayName string, +) (err error) { + _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) + return +} diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go new file mode 100644 index 00000000..199c4606 --- /dev/null +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -0,0 +1,392 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + + // Import the postgres database driver. + _ "github.com/mattn/go-sqlite3" +) + +// Database represents an account database +type Database struct { + db *sql.DB + common.PartitionOffsetStatements + accounts accountsStatements + profiles profilesStatements + memberships membershipStatements + accountDatas accountDataStatements + threepids threepidStatements + filter filterStatements + serverName gomatrixserverlib.ServerName +} + +// NewDatabase creates a new accounts and profiles database +func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { + var db *sql.DB + var err error + if db, err = sql.Open("sqlite3", dataSourceName); err != nil { + return nil, err + } + partitions := common.PartitionOffsetStatements{} + if err = partitions.Prepare(db, "account"); err != nil { + return nil, err + } + a := accountsStatements{} + if err = a.prepare(db, serverName); err != nil { + return nil, err + } + p := profilesStatements{} + if err = p.prepare(db); err != nil { + return nil, err + } + m := membershipStatements{} + if err = m.prepare(db); err != nil { + return nil, err + } + ac := accountDataStatements{} + if err = ac.prepare(db); err != nil { + return nil, err + } + t := threepidStatements{} + if err = t.prepare(db); err != nil { + return nil, err + } + f := filterStatements{} + if err = f.prepare(db); err != nil { + return nil, err + } + return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil +} + +// GetAccountByPassword returns the account associated with the given localpart and password. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByPassword( + ctx context.Context, localpart, plaintextPassword string, +) (*authtypes.Account, error) { + hash, err := d.accounts.selectPasswordHash(ctx, localpart) + if err != nil { + return nil, err + } + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { + return nil, err + } + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} + +// GetProfileByLocalpart returns the profile associated with the given localpart. +// Returns sql.ErrNoRows if no profile exists which matches the given localpart. +func (d *Database) GetProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + return d.profiles.selectProfileByLocalpart(ctx, localpart) +} + +// SetAvatarURL updates the avatar URL of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) error { + return d.profiles.setAvatarURL(ctx, localpart, avatarURL) +} + +// SetDisplayName updates the display name of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetDisplayName( + ctx context.Context, localpart string, displayName string, +) error { + return d.profiles.setDisplayName(ctx, localpart, displayName) +} + +// CreateAccount makes a new account with the given login name and password, and creates an empty profile +// for this account. If no password is supplied, the account will be a passwordless account. If the +// account already exists, it will return nil, nil. +func (d *Database) CreateAccount( + ctx context.Context, localpart, plaintextPassword, appserviceID string, +) (*authtypes.Account, error) { + var err error + + // Generate a password hash if this is not a password-less user + hash := "" + if plaintextPassword != "" { + hash, err = hashPassword(plaintextPassword) + if err != nil { + return nil, err + } + } + if err := d.profiles.insertProfile(ctx, localpart); err != nil { + if common.IsUniqueConstraintViolationErr(err) { + return nil, nil + } + return nil, err + } + if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{ + "global": { + "content": [], + "override": [], + "room": [], + "sender": [], + "underride": [] + } + }`); err != nil { + return nil, err + } + return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) +} + +// SaveMembership saves the user matching a given localpart as a member of a given +// room. It also stores the ID of the membership event. +// If a membership already exists between the user and the room, or if the +// insert fails, returns the SQL error +func (d *Database) saveMembership( + ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, +) error { + return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID) +} + +// removeMembershipsByEventIDs removes the memberships corresponding to the +// `join` membership events IDs in the eventIDs slice. +// If the removal fails, or if there is no membership to remove, returns an error +func (d *Database) removeMembershipsByEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) error { + return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs) +} + +// UpdateMemberships adds the "join" membership events included in a given state +// events array, and removes those which ID is included in a given array of events +// IDs. All of the process is run in a transaction, which commits only once/if every +// insertion and deletion has been successfully processed. +// Returns a SQL error if there was an issue with any part of the process +func (d *Database) UpdateMemberships( + ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { + return err + } + + for _, event := range eventsToAdd { + if err := d.newMembership(ctx, txn, event); err != nil { + return err + } + } + + return nil + }) +} + +// GetMembershipInRoomByLocalpart returns the membership for an user +// matching the given localpart if he is a member of the room matching roomID, +// if not sql.ErrNoRows is returned. +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembershipInRoomByLocalpart( + ctx context.Context, localpart, roomID string, +) (authtypes.Membership, error) { + return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID) +} + +// GetMembershipsByLocalpart returns an array containing the memberships for all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembershipsByLocalpart( + ctx context.Context, localpart string, +) (memberships []authtypes.Membership, err error) { + return d.memberships.selectMembershipsByLocalpart(ctx, localpart) +} + +// newMembership saves a new membership in the database. +// If the event isn't a valid m.room.member event with type `join`, does nothing. +// If an error occurred, returns the SQL error +func (d *Database) newMembership( + ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event, +) error { + if ev.Type() == "m.room.member" && ev.StateKey() != nil { + localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + if err != nil { + return err + } + + // We only want state events from local users + if string(serverName) != string(d.serverName) { + return nil + } + + eventID := ev.EventID() + roomID := ev.RoomID() + membership, err := ev.Membership() + if err != nil { + return err + } + + // Only "join" membership events can be considered as new memberships + if membership == gomatrixserverlib.Join { + if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil { + return err + } + } + } + return nil +} + +// SaveAccountData saves new account data for a given user and a given room. +// If the account data is not specific to a room, the room ID should be an empty string +// If an account data already exists for a given set (user, room, data type), it will +// update the corresponding row with the new content +// Returns a SQL error if there was an issue with the insertion/update +func (d *Database) SaveAccountData( + ctx context.Context, localpart, roomID, dataType, content string, +) error { + return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) +} + +// GetAccountData returns account data related to a given localpart +// If no account data could be found, returns an empty arrays +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountData(ctx context.Context, localpart string) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + return d.accountDatas.selectAccountData(ctx, localpart) +} + +// GetAccountDataByType returns account data matching a given +// localpart, room ID and type. +// If no account data could be found, returns nil +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data *gomatrixserverlib.ClientEvent, err error) { + return d.accountDatas.selectAccountDataByType( + ctx, localpart, roomID, dataType, + ) +} + +// GetNewNumericLocalpart generates and returns a new unused numeric localpart +func (d *Database) GetNewNumericLocalpart( + ctx context.Context, +) (int64, error) { + return d.accounts.selectNewNumericLocalpart(ctx) +} + +func hashPassword(plaintext string) (hash string, err error) { + hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost) + return string(hashBytes), err +} + +// Err3PIDInUse is the error returned when trying to save an association involving +// a third-party identifier which is already associated to a local user. +var Err3PIDInUse = errors.New("This third-party identifier is already in use") + +// SaveThreePIDAssociation saves the association between a third party identifier +// and a local Matrix user (identified by the user's ID's local part). +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) SaveThreePIDAssociation( + ctx context.Context, threepid, localpart, medium string, +) (err error) { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + user, err := d.threepids.selectLocalpartForThreePID( + ctx, txn, threepid, medium, + ) + if err != nil { + return err + } + + if len(user) > 0 { + return Err3PIDInUse + } + + return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) + }) +} + +// RemoveThreePIDAssociation removes the association involving a given third-party +// identifier. +// If no association exists involving this third-party identifier, returns nothing. +// If there was a problem talking to the database, returns an error. +func (d *Database) RemoveThreePIDAssociation( + ctx context.Context, threepid string, medium string, +) (err error) { + return d.threepids.deleteThreePID(ctx, threepid, medium) +} + +// GetLocalpartForThreePID looks up the localpart associated with a given third-party +// identifier. +// If no association involves the given third-party idenfitier, returns an empty +// string. +// Returns an error if there was a problem talking to the database. +func (d *Database) GetLocalpartForThreePID( + ctx context.Context, threepid string, medium string, +) (localpart string, err error) { + return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) +} + +// GetThreePIDsForLocalpart looks up the third-party identifiers associated with +// a given local user. +// If no association is known for this user, returns an empty slice. +// Returns an error if there was an issue talking to the database. +func (d *Database) GetThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) +} + +// GetFilter looks up the filter associated with a given local user and filter ID. +// Returns a filter structure. Otherwise returns an error if no such filter exists +// or if there was an error talking to the database. +func (d *Database) GetFilter( + ctx context.Context, localpart string, filterID string, +) (*gomatrixserverlib.Filter, error) { + return d.filter.selectFilter(ctx, localpart, filterID) +} + +// PutFilter puts the passed filter into the database. +// Returns the filterID as a string. Otherwise returns an error if something +// goes wrong. +func (d *Database) PutFilter( + ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, +) (string, error) { + return d.filter.insertFilter(ctx, filter, localpart) +} + +// CheckAccountAvailability checks if the username/localpart is already present +// in the database. +// If the DB returns sql.ErrNoRows the Localpart isn't taken. +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { + _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) + if err == sql.ErrNoRows { + return true, nil + } + return false, err +} + +// GetAccountByLocalpart returns the account associated with the given localpart. +// This function assumes the request is authenticated or the account data is used only internally. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +) (*authtypes.Account, error) { + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} diff --git a/clientapi/auth/storage/accounts/threepid_table.go b/clientapi/auth/storage/accounts/sqlite3/threepid_table.go index a03aa4f8..53f6408d 100644 --- a/clientapi/auth/storage/accounts/threepid_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/threepid_table.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package accounts +package sqlite3 import ( "context" diff --git a/clientapi/auth/storage/accounts/storage.go b/clientapi/auth/storage/accounts/storage.go index 7cfc63c0..1dfd5f1f 100644 --- a/clientapi/auth/storage/accounts/storage.go +++ b/clientapi/auth/storage/accounts/storage.go @@ -1,392 +1,56 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package accounts import ( "context" - "database/sql" "errors" + "net/url" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres" + "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" - - // Import the postgres database driver. - _ "github.com/lib/pq" ) -// Database represents an account database -type Database struct { - db *sql.DB - common.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - memberships membershipStatements - accountDatas accountDataStatements - threepids threepidStatements - filter filterStatements - serverName gomatrixserverlib.ServerName -} - -// NewDatabase creates a new accounts and profiles database -func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { - var db *sql.DB - var err error - if db, err = sql.Open("postgres", dataSourceName); err != nil { - return nil, err - } - partitions := common.PartitionOffsetStatements{} - if err = partitions.Prepare(db, "account"); err != nil { - return nil, err - } - a := accountsStatements{} - if err = a.prepare(db, serverName); err != nil { - return nil, err - } - p := profilesStatements{} - if err = p.prepare(db); err != nil { - return nil, err - } - m := membershipStatements{} - if err = m.prepare(db); err != nil { - return nil, err - } - ac := accountDataStatements{} - if err = ac.prepare(db); err != nil { - return nil, err - } - t := threepidStatements{} - if err = t.prepare(db); err != nil { - return nil, err - } - f := filterStatements{} - if err = f.prepare(db); err != nil { - return nil, err - } - return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil -} - -// GetAccountByPassword returns the account associated with the given localpart and password. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, -) (*authtypes.Account, error) { - hash, err := d.accounts.selectPasswordHash(ctx, localpart) +type Database interface { + common.PartitionStorer + GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error) + GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error + SetDisplayName(ctx context.Context, localpart string, displayName string) error + CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error) + UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error + GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) + GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) + SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error + GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) + GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error) + GetNewNumericLocalpart(ctx context.Context) (int64, error) + SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) + RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) + PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) + CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) + GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error) +} + +func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) { + uri, err := url.Parse(dataSourceName) if err != nil { - return nil, err + return postgres.NewDatabase(dataSourceName, serverName) } - if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { - return nil, err + switch uri.Scheme { + case "postgres": + return postgres.NewDatabase(dataSourceName, serverName) + case "file": + return sqlite3.NewDatabase(dataSourceName, serverName) + default: + return postgres.NewDatabase(dataSourceName, serverName) } - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} - -// GetProfileByLocalpart returns the profile associated with the given localpart. -// Returns sql.ErrNoRows if no profile exists which matches the given localpart. -func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, -) (*authtypes.Profile, error) { - return d.profiles.selectProfileByLocalpart(ctx, localpart) -} - -// SetAvatarURL updates the avatar URL of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, -) error { - return d.profiles.setAvatarURL(ctx, localpart, avatarURL) -} - -// SetDisplayName updates the display name of the profile associated with the given -// localpart. Returns an error if something went wrong with the SQL query -func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, -) error { - return d.profiles.setDisplayName(ctx, localpart, displayName) -} - -// CreateAccount makes a new account with the given login name and password, and creates an empty profile -// for this account. If no password is supplied, the account will be a passwordless account. If the -// account already exists, it will return nil, nil. -func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, -) (*authtypes.Account, error) { - var err error - - // Generate a password hash if this is not a password-less user - hash := "" - if plaintextPassword != "" { - hash, err = hashPassword(plaintextPassword) - if err != nil { - return nil, err - } - } - if err := d.profiles.insertProfile(ctx, localpart); err != nil { - if common.IsUniqueConstraintViolationErr(err) { - return nil, nil - } - return nil, err - } - if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`); err != nil { - return nil, err - } - return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) -} - -// SaveMembership saves the user matching a given localpart as a member of a given -// room. It also stores the ID of the membership event. -// If a membership already exists between the user and the room, or if the -// insert fails, returns the SQL error -func (d *Database) saveMembership( - ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, -) error { - return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID) -} - -// removeMembershipsByEventIDs removes the memberships corresponding to the -// `join` membership events IDs in the eventIDs slice. -// If the removal fails, or if there is no membership to remove, returns an error -func (d *Database) removeMembershipsByEventIDs( - ctx context.Context, txn *sql.Tx, eventIDs []string, -) error { - return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs) -} - -// UpdateMemberships adds the "join" membership events included in a given state -// events array, and removes those which ID is included in a given array of events -// IDs. All of the process is run in a transaction, which commits only once/if every -// insertion and deletion has been successfully processed. -// Returns a SQL error if there was an issue with any part of the process -func (d *Database) UpdateMemberships( - ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, -) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { - return err - } - - for _, event := range eventsToAdd { - if err := d.newMembership(ctx, txn, event); err != nil { - return err - } - } - - return nil - }) -} - -// GetMembershipInRoomByLocalpart returns the membership for an user -// matching the given localpart if he is a member of the room matching roomID, -// if not sql.ErrNoRows is returned. -// If there was an issue during the retrieval, returns the SQL error -func (d *Database) GetMembershipInRoomByLocalpart( - ctx context.Context, localpart, roomID string, -) (authtypes.Membership, error) { - return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID) -} - -// GetMembershipsByLocalpart returns an array containing the memberships for all -// the rooms a user matching a given localpart is a member of -// If no membership match the given localpart, returns an empty array -// If there was an issue during the retrieval, returns the SQL error -func (d *Database) GetMembershipsByLocalpart( - ctx context.Context, localpart string, -) (memberships []authtypes.Membership, err error) { - return d.memberships.selectMembershipsByLocalpart(ctx, localpart) -} - -// newMembership saves a new membership in the database. -// If the event isn't a valid m.room.member event with type `join`, does nothing. -// If an error occurred, returns the SQL error -func (d *Database) newMembership( - ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event, -) error { - if ev.Type() == "m.room.member" && ev.StateKey() != nil { - localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) - if err != nil { - return err - } - - // We only want state events from local users - if string(serverName) != string(d.serverName) { - return nil - } - - eventID := ev.EventID() - roomID := ev.RoomID() - membership, err := ev.Membership() - if err != nil { - return err - } - - // Only "join" membership events can be considered as new memberships - if membership == gomatrixserverlib.Join { - if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil { - return err - } - } - } - return nil -} - -// SaveAccountData saves new account data for a given user and a given room. -// If the account data is not specific to a room, the room ID should be an empty string -// If an account data already exists for a given set (user, room, data type), it will -// update the corresponding row with the new content -// Returns a SQL error if there was an issue with the insertion/update -func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType, content string, -) error { - return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) -} - -// GetAccountData returns account data related to a given localpart -// If no account data could be found, returns an empty arrays -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, - err error, -) { - return d.accountDatas.selectAccountData(ctx, localpart) -} - -// GetAccountDataByType returns account data matching a given -// localpart, room ID and type. -// If no account data could be found, returns nil -// Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { - return d.accountDatas.selectAccountDataByType( - ctx, localpart, roomID, dataType, - ) -} - -// GetNewNumericLocalpart generates and returns a new unused numeric localpart -func (d *Database) GetNewNumericLocalpart( - ctx context.Context, -) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx) -} - -func hashPassword(plaintext string) (hash string, err error) { - hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost) - return string(hashBytes), err } // Err3PIDInUse is the error returned when trying to save an association involving // a third-party identifier which is already associated to a local user. var Err3PIDInUse = errors.New("This third-party identifier is already in use") - -// SaveThreePIDAssociation saves the association between a third party identifier -// and a local Matrix user (identified by the user's ID's local part). -// If the third-party identifier is already part of an association, returns Err3PIDInUse. -// Returns an error if there was a problem talking to the database. -func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, -) (err error) { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { - user, err := d.threepids.selectLocalpartForThreePID( - ctx, txn, threepid, medium, - ) - if err != nil { - return err - } - - if len(user) > 0 { - return Err3PIDInUse - } - - return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) - }) -} - -// RemoveThreePIDAssociation removes the association involving a given third-party -// identifier. -// If no association exists involving this third-party identifier, returns nothing. -// If there was a problem talking to the database, returns an error. -func (d *Database) RemoveThreePIDAssociation( - ctx context.Context, threepid string, medium string, -) (err error) { - return d.threepids.deleteThreePID(ctx, threepid, medium) -} - -// GetLocalpartForThreePID looks up the localpart associated with a given third-party -// identifier. -// If no association involves the given third-party idenfitier, returns an empty -// string. -// Returns an error if there was a problem talking to the database. -func (d *Database) GetLocalpartForThreePID( - ctx context.Context, threepid string, medium string, -) (localpart string, err error) { - return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) -} - -// GetThreePIDsForLocalpart looks up the third-party identifiers associated with -// a given local user. -// If no association is known for this user, returns an empty slice. -// Returns an error if there was an issue talking to the database. -func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, -) (threepids []authtypes.ThreePID, err error) { - return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) -} - -// GetFilter looks up the filter associated with a given local user and filter ID. -// Returns a filter structure. Otherwise returns an error if no such filter exists -// or if there was an error talking to the database. -func (d *Database) GetFilter( - ctx context.Context, localpart string, filterID string, -) (*gomatrixserverlib.Filter, error) { - return d.filter.selectFilter(ctx, localpart, filterID) -} - -// PutFilter puts the passed filter into the database. -// Returns the filterID as a string. Otherwise returns an error if something -// goes wrong. -func (d *Database) PutFilter( - ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, -) (string, error) { - return d.filter.insertFilter(ctx, filter, localpart) -} - -// CheckAccountAvailability checks if the username/localpart is already present -// in the database. -// If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) - if err == sql.ErrNoRows { - return true, nil - } - return false, err -} - -// GetAccountByLocalpart returns the account associated with the given localpart. -// This function assumes the request is authenticated or the account data is used only internally. -// Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*authtypes.Account, error) { - return d.accounts.selectAccountByLocalpart(ctx, localpart) -} diff --git a/clientapi/auth/storage/devices/devices_table.go b/clientapi/auth/storage/devices/postgres/devices_table.go index 99741247..349bf1ef 100644 --- a/clientapi/auth/storage/devices/devices_table.go +++ b/clientapi/auth/storage/devices/postgres/devices_table.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package devices +package postgres import ( "context" diff --git a/clientapi/auth/storage/devices/postgres/storage.go b/clientapi/auth/storage/devices/postgres/storage.go new file mode 100644 index 00000000..221c3998 --- /dev/null +++ b/clientapi/auth/storage/devices/postgres/storage.go @@ -0,0 +1,182 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/gomatrixserverlib" +) + +// The length of generated device IDs +var deviceIDByteLength = 6 + +// Database represents a device database. +type Database struct { + db *sql.DB + devices devicesStatements +} + +// NewDatabase creates a new device database +func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { + var db *sql.DB + var err error + if db, err = sql.Open("postgres", dataSourceName); err != nil { + return nil, err + } + d := devicesStatements{} + if err = d.prepare(db, serverName); err != nil { + return nil, err + } + return &Database{db, d}, nil +} + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*authtypes.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*authtypes.Device, error) { + return d.devices.selectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]authtypes.Device, error) { + return d.devices.selectDevicesByLocalpart(ctx, localpart) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, +) (dev *authtypes.Device, returnErr error) { + if deviceID != nil { + returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} diff --git a/clientapi/auth/storage/devices/sqlite3/devices_table.go b/clientapi/auth/storage/devices/sqlite3/devices_table.go new file mode 100644 index 00000000..dc88890d --- /dev/null +++ b/clientapi/auth/storage/devices/sqlite3/devices_table.go @@ -0,0 +1,243 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/common" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const devicesSchema = ` +-- This sequence is used for automatic allocation of session_id. +-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; + +-- Stores data about devices. +CREATE TABLE IF NOT EXISTS device_devices ( + access_token TEXT PRIMARY KEY, + session_id INTEGER, + device_id TEXT , + localpart TEXT , + created_ts BIGINT, + display_name TEXT, + + UNIQUE (localpart, device_id) +); +` + +const insertDeviceSQL = "" + + "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + +const selectDevicesCountSQL = "" + + "SELECT COUNT(access_token) FROM device_devices" + +const selectDeviceByTokenSQL = "" + + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + +const selectDeviceByIDSQL = "" + + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" + +const selectDevicesByLocalpartSQL = "" + + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + +const updateDeviceNameSQL = "" + + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + +const deleteDeviceSQL = "" + + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + +const deleteDevicesByLocalpartSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1" + +const deleteDevicesSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" + +type devicesStatements struct { + db *sql.DB + insertDeviceStmt *sql.Stmt + selectDevicesCountStmt *sql.Stmt + selectDeviceByTokenStmt *sql.Stmt + selectDeviceByIDStmt *sql.Stmt + selectDevicesByLocalpartStmt *sql.Stmt + updateDeviceNameStmt *sql.Stmt + deleteDeviceStmt *sql.Stmt + deleteDevicesByLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + _, err = db.Exec(devicesSchema) + if err != nil { + return + } + if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { + return + } + if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { + return + } + if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { + return + } + if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { + return + } + if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { + return + } + if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { + return + } + if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { + return + } + if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertDevice creates a new device. Returns an error if any device with the same access token already exists. +// Returns an error if the user already has a device with the given device ID. +// Returns the device on success. +func (s *devicesStatements) insertDevice( + ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, + displayName *string, +) (*authtypes.Device, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + var sessionID int64 + countStmt := common.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := common.TxStmt(txn, s.insertDeviceStmt) + if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + return nil, err + } + sessionID++ + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { + return nil, err + } + return &authtypes.Device{ + ID: id, + UserID: userutil.MakeUserID(localpart, s.serverName), + AccessToken: accessToken, + SessionID: sessionID, + }, nil +} + +func (s *devicesStatements) deleteDevice( + ctx context.Context, txn *sql.Tx, id, localpart string, +) error { + stmt := common.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) + return err +} + +func (s *devicesStatements) deleteDevices( + ctx context.Context, txn *sql.Tx, localpart string, devices []string, +) error { + orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1) + prep, err := s.db.Prepare(orig) + if err != nil { + return err + } + stmt := common.TxStmt(txn, prep) + params := make([]interface{}, len(devices)+1) + params[0] = localpart + for i, v := range devices { + params[i+1] = v + } + params = append(params, params...) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *devicesStatements) deleteDevicesByLocalpart( + ctx context.Context, txn *sql.Tx, localpart string, +) error { + stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart) + return err +} + +func (s *devicesStatements) updateDeviceName( + ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, +) error { + stmt := common.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err +} + +func (s *devicesStatements) selectDeviceByToken( + ctx context.Context, accessToken string, +) (*authtypes.Device, error) { + var dev authtypes.Device + var localpart string + stmt := s.selectDeviceByTokenStmt + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + if err == nil { + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.AccessToken = accessToken + } + return &dev, err +} + +// selectDeviceByID retrieves a device from the database with the given user +// localpart and deviceID +func (s *devicesStatements) selectDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*authtypes.Device, error) { + var dev authtypes.Device + var created sql.NullInt64 + stmt := s.selectDeviceByIDStmt + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created) + if err == nil { + dev.ID = deviceID + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + } + return &dev, err +} + +func (s *devicesStatements) selectDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]authtypes.Device, error) { + devices := []authtypes.Device{} + + rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) + + if err != nil { + return devices, err + } + + for rows.Next() { + var dev authtypes.Device + err = rows.Scan(&dev.ID) + if err != nil { + return devices, err + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + + return devices, nil +} diff --git a/clientapi/auth/storage/devices/sqlite3/storage.go b/clientapi/auth/storage/devices/sqlite3/storage.go new file mode 100644 index 00000000..e1ce6f00 --- /dev/null +++ b/clientapi/auth/storage/devices/sqlite3/storage.go @@ -0,0 +1,184 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/gomatrixserverlib" + + _ "github.com/mattn/go-sqlite3" +) + +// The length of generated device IDs +var deviceIDByteLength = 6 + +// Database represents a device database. +type Database struct { + db *sql.DB + devices devicesStatements +} + +// NewDatabase creates a new device database +func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { + var db *sql.DB + var err error + if db, err = sql.Open("sqlite3", dataSourceName); err != nil { + return nil, err + } + d := devicesStatements{} + if err = d.prepare(db, serverName); err != nil { + return nil, err + } + return &Database{db, d}, nil +} + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*authtypes.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*authtypes.Device, error) { + return d.devices.selectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]authtypes.Device, error) { + return d.devices.selectDevicesByLocalpart(ctx, localpart) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, +) (dev *authtypes.Device, returnErr error) { + if deviceID != nil { + returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart string, +) error { + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} diff --git a/clientapi/auth/storage/devices/storage.go b/clientapi/auth/storage/devices/storage.go index 150180c1..82f75640 100644 --- a/clientapi/auth/storage/devices/storage.go +++ b/clientapi/auth/storage/devices/storage.go @@ -1,182 +1,37 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package devices import ( "context" - "crypto/rand" - "database/sql" - "encoding/base64" + "net/url" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres" + "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3" "github.com/matrix-org/gomatrixserverlib" ) -// The length of generated device IDs -var deviceIDByteLength = 6 - -// Database represents a device database. -type Database struct { - db *sql.DB - devices devicesStatements -} - -// NewDatabase creates a new device database -func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { - var db *sql.DB - var err error - if db, err = sql.Open("postgres", dataSourceName); err != nil { - return nil, err - } - d := devicesStatements{} - if err = d.prepare(db, serverName); err != nil { - return nil, err - } - return &Database{db, d}, nil -} - -// GetDeviceByAccessToken returns the device matching the given access token. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByAccessToken( - ctx context.Context, token string, -) (*authtypes.Device, error) { - return d.devices.selectDeviceByToken(ctx, token) -} - -// GetDeviceByID returns the device matching the given ID. -// Returns sql.ErrNoRows if no matching device was found. -func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, -) (*authtypes.Device, error) { - return d.devices.selectDeviceByID(ctx, localpart, deviceID) -} - -// GetDevicesByLocalpart returns the devices matching the given localpart. -func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, -) ([]authtypes.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, localpart) -} - -// CreateDevice makes a new device associated with the given user ID localpart. -// If there is already a device with the same device ID for this user, that access token will be revoked -// and replaced with the given accessToken. If the given accessToken is already in use for another device, -// an error will be returned. -// If no device ID is given one is generated. -// Returns the device on success. -func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, -) (dev *authtypes.Device, returnErr error) { - if deviceID != nil { - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { - return err - } - - dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) - return err - }) - } else { - // We generate device IDs in a loop in case its already taken. - // We cap this at going round 5 times to ensure we don't spin forever - var newDeviceID string - for i := 1; i <= 5; i++ { - newDeviceID, returnErr = generateDeviceID() - if returnErr != nil { - return - } - - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { - var err error - dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) - return err - }) - if returnErr == nil { - return - } - } - } - return +type Database interface { + GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error) + GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error) + CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error + RemoveDevice(ctx context.Context, deviceID, localpart string) error + RemoveDevices(ctx context.Context, localpart string, devices []string) error + RemoveAllDevices(ctx context.Context, localpart string) error } -// generateDeviceID creates a new device id. Returns an error if failed to generate -// random bytes. -func generateDeviceID() (string, error) { - b := make([]byte, deviceIDByteLength) - _, err := rand.Read(b) +func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) { + uri, err := url.Parse(dataSourceName) if err != nil { - return "", err + return postgres.NewDatabase(dataSourceName, serverName) + } + switch uri.Scheme { + case "postgres": + return postgres.NewDatabase(dataSourceName, serverName) + case "file": + return sqlite3.NewDatabase(dataSourceName, serverName) + default: + return postgres.NewDatabase(dataSourceName, serverName) } - // url-safe no padding - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// UpdateDevice updates the given device with the display name. -// Returns SQL error if there are problems and nil on success. -func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, -) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) - }) -} - -// RemoveDevice revokes a device by deleting the entry in the database -// matching with the given device ID and user ID localpart. -// If the device doesn't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevice( - ctx context.Context, deviceID, localpart string, -) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveDevices revokes one or more devices by deleting the entry in the database -// matching with the given device IDs and user ID localpart. -// If the devices don't exist, it will not return an error -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, -) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { - return err - } - return nil - }) -} - -// RemoveAllDevices revokes devices by deleting the entry in the -// database matching the given user ID localpart. -// If something went wrong during the deletion, it will return the SQL error. -func (d *Database) RemoveAllDevices( - ctx context.Context, localpart string, -) error { - return common.WithTransaction(d.db, func(txn *sql.Tx) error { - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { - return err - } - return nil - }) } diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index c911fecc..bb44e016 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -34,8 +34,8 @@ import ( // component. func SetupClientAPIComponent( base *basecomponent.BaseDendrite, - deviceDB *devices.Database, - accountsDB *accounts.Database, + deviceDB devices.Database, + accountsDB accounts.Database, federation *gomatrixserverlib.FederationClient, keyRing *gomatrixserverlib.KeyRing, aliasAPI roomserverAPI.RoomserverAliasAPI, diff --git a/clientapi/consumers/roomserver.go b/clientapi/consumers/roomserver.go index 0ee7c6bf..a6528151 100644 --- a/clientapi/consumers/roomserver.go +++ b/clientapi/consumers/roomserver.go @@ -31,7 +31,7 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { roomServerConsumer *common.ContinualConsumer - db *accounts.Database + db accounts.Database query api.RoomserverQueryAPI serverName string } @@ -40,7 +40,7 @@ type OutputRoomEventConsumer struct { func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, - store *accounts.Database, + store accounts.Database, queryAPI api.RoomserverQueryAPI, ) *OutputRoomEventConsumer { diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index bbc8c258..8ae9de2d 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -30,7 +30,7 @@ import ( // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} func GetAccountData( - req *http.Request, accountDB *accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *authtypes.Device, userID string, roomID string, dataType string, ) util.JSONResponse { if userID != device.UserID { @@ -62,7 +62,7 @@ func GetAccountData( // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} func SaveAccountData( - req *http.Request, accountDB *accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *authtypes.Device, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { if userID != device.UserID { diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index f6f06421..2b1245b9 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -135,7 +135,7 @@ type fledglingEvent struct { func CreateRoom( req *http.Request, device *authtypes.Device, cfg *config.Dendrite, producer *producers.RoomserverProducer, - accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, + accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { // TODO (#267): Check room ID doesn't clash with an existing one, and we @@ -149,7 +149,7 @@ func CreateRoom( func createRoom( req *http.Request, device *authtypes.Device, cfg *config.Dendrite, roomID string, producer *producers.RoomserverProducer, - accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, + accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI, asAPI appserviceAPI.AppServiceQueryAPI, ) util.JSONResponse { logger := util.GetLogger(req.Context()) diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index eb7cd0b0..9b8647cd 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -46,7 +46,7 @@ type devicesDeleteJSON struct { // GetDeviceByID handles /devices/{deviceID} func GetDeviceByID( - req *http.Request, deviceDB *devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *authtypes.Device, deviceID string, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -76,7 +76,7 @@ func GetDeviceByID( // GetDevicesByLocalpart handles /devices func GetDevicesByLocalpart( - req *http.Request, deviceDB *devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *authtypes.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -107,7 +107,7 @@ func GetDevicesByLocalpart( // UpdateDeviceByID handles PUT on /devices/{deviceID} func UpdateDeviceByID( - req *http.Request, deviceDB *devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *authtypes.Device, deviceID string, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -153,7 +153,7 @@ func UpdateDeviceByID( // DeleteDeviceById handles DELETE requests to /devices/{deviceId} func DeleteDeviceById( - req *http.Request, deviceDB *devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *authtypes.Device, deviceID string, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) @@ -176,7 +176,7 @@ func DeleteDeviceById( // DeleteDevices handles POST requests to /delete_devices func DeleteDevices( - req *http.Request, deviceDB *devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *authtypes.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { diff --git a/clientapi/routing/filter.go b/clientapi/routing/filter.go index eec501ff..583b2395 100644 --- a/clientapi/routing/filter.go +++ b/clientapi/routing/filter.go @@ -27,7 +27,7 @@ import ( // GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId} func GetFilter( - req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string, + req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, filterID string, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ @@ -63,7 +63,7 @@ type filterResponse struct { //PutFilter implements POST /_matrix/client/r0/user/{userId}/filter func PutFilter( - req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, + req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, ) util.JSONResponse { if userID != device.UserID { return util.JSONResponse{ diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index 8b3f3740..5e6f3e55 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -45,7 +45,7 @@ func JoinRoomByIDOrAlias( queryAPI roomserverAPI.RoomserverQueryAPI, aliasAPI roomserverAPI.RoomserverAliasAPI, keyRing gomatrixserverlib.KeyRing, - accountDB *accounts.Database, + accountDB accounts.Database, ) util.JSONResponse { var content map[string]interface{} // must be a JSON object if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil { diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 2f4fb83c..b8364ed9 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -70,7 +70,7 @@ func passwordLogin() loginFlows { // Login implements GET and POST /login func Login( - req *http.Request, accountDB *accounts.Database, deviceDB *devices.Database, + req *http.Request, accountDB accounts.Database, deviceDB devices.Database, cfg *config.Dendrite, ) util.JSONResponse { if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options @@ -153,7 +153,7 @@ func Login( func getDevice( ctx context.Context, r passwordRequest, - deviceDB *devices.Database, + deviceDB devices.Database, acc *authtypes.Account, token string, ) (dev *authtypes.Device, err error) { diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index 3294fbcd..0ac9ca4a 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -26,7 +26,7 @@ import ( // Logout handles POST /logout func Logout( - req *http.Request, deviceDB *devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *authtypes.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -45,7 +45,7 @@ func Logout( // LogoutAll handles POST /logout/all func LogoutAll( - req *http.Request, deviceDB *devices.Database, device *authtypes.Device, + req *http.Request, deviceDB devices.Database, device *authtypes.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 8b8b3a0f..68c131a2 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -40,7 +40,7 @@ var errMissingUserID = errors.New("'user_id' must be supplied") // SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite) // by building a m.room.member event then sending it to the room server func SendMembership( - req *http.Request, accountDB *accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *authtypes.Device, roomID string, membership string, cfg *config.Dendrite, queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI, producer *producers.RoomserverProducer, @@ -116,7 +116,7 @@ func SendMembership( func buildMembershipEvent( ctx context.Context, - body threepid.MembershipRequest, accountDB *accounts.Database, + body threepid.MembershipRequest, accountDB accounts.Database, device *authtypes.Device, membership, roomID string, cfg *config.Dendrite, evTime time.Time, @@ -166,7 +166,7 @@ func loadProfile( ctx context.Context, userID string, cfg *config.Dendrite, - accountDB *accounts.Database, + accountDB accounts.Database, asAPI appserviceAPI.AppServiceQueryAPI, ) (*authtypes.Profile, error) { _, serverName, err := gomatrixserverlib.SplitID('@', userID) @@ -216,7 +216,7 @@ func checkAndProcessThreepid( body *threepid.MembershipRequest, cfg *config.Dendrite, queryAPI roomserverAPI.RoomserverQueryAPI, - accountDB *accounts.Database, + accountDB accounts.Database, producer *producers.RoomserverProducer, membership, roomID string, evTime time.Time, diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 4688b19e..9b091ddf 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -36,7 +36,7 @@ import ( // GetProfile implements GET /profile/{userID} func GetProfile( - req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, + req *http.Request, accountDB accounts.Database, cfg *config.Dendrite, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, @@ -64,7 +64,7 @@ func GetProfile( // GetAvatarURL implements GET /profile/{userID}/avatar_url func GetAvatarURL( - req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, + req *http.Request, accountDB accounts.Database, cfg *config.Dendrite, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -90,7 +90,7 @@ func GetAvatarURL( // SetAvatarURL implements PUT /profile/{userID}/avatar_url func SetAvatarURL( - req *http.Request, accountDB *accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *authtypes.Device, userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, ) util.JSONResponse { @@ -170,7 +170,7 @@ func SetAvatarURL( // GetDisplayName implements GET /profile/{userID}/displayname func GetDisplayName( - req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite, + req *http.Request, accountDB accounts.Database, cfg *config.Dendrite, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, ) util.JSONResponse { @@ -196,7 +196,7 @@ func GetDisplayName( // SetDisplayName implements PUT /profile/{userID}/displayname func SetDisplayName( - req *http.Request, accountDB *accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *authtypes.Device, userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite, rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI, ) util.JSONResponse { @@ -279,7 +279,7 @@ func SetDisplayName( // Returns an error when something goes wrong or specifically // common.ErrProfileNoExists when the profile doesn't exist. func getProfile( - ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite, + ctx context.Context, accountDB accounts.Database, cfg *config.Dendrite, userID string, asAPI appserviceAPI.AppServiceQueryAPI, federation *gomatrixserverlib.FederationClient, diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 4375faaf..9d67d998 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -440,8 +440,8 @@ func validateApplicationService( // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register func Register( req *http.Request, - accountDB *accounts.Database, - deviceDB *devices.Database, + accountDB accounts.Database, + deviceDB devices.Database, cfg *config.Dendrite, ) util.JSONResponse { @@ -513,8 +513,8 @@ func handleGuestRegistration( req *http.Request, r registerRequest, cfg *config.Dendrite, - accountDB *accounts.Database, - deviceDB *devices.Database, + accountDB accounts.Database, + deviceDB devices.Database, ) util.JSONResponse { //Generate numeric local part for guest user @@ -570,8 +570,8 @@ func handleRegistrationFlow( r registerRequest, sessionID string, cfg *config.Dendrite, - accountDB *accounts.Database, - deviceDB *devices.Database, + accountDB accounts.Database, + deviceDB devices.Database, ) util.JSONResponse { // TODO: Shared secret registration (create new user scripts) // TODO: Enable registration config flag @@ -668,8 +668,8 @@ func handleApplicationServiceRegistration( req *http.Request, r registerRequest, cfg *config.Dendrite, - accountDB *accounts.Database, - deviceDB *devices.Database, + accountDB accounts.Database, + deviceDB devices.Database, ) util.JSONResponse { // Check if we previously had issues extracting the access token from the // request. @@ -707,8 +707,8 @@ func checkAndCompleteFlow( r registerRequest, sessionID string, cfg *config.Dendrite, - accountDB *accounts.Database, - deviceDB *devices.Database, + accountDB accounts.Database, + deviceDB devices.Database, ) util.JSONResponse { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue @@ -730,8 +730,8 @@ func checkAndCompleteFlow( // LegacyRegister process register requests from the legacy v1 API func LegacyRegister( req *http.Request, - accountDB *accounts.Database, - deviceDB *devices.Database, + accountDB accounts.Database, + deviceDB devices.Database, cfg *config.Dendrite, ) util.JSONResponse { var r legacyRegisterRequest @@ -814,8 +814,8 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u // not all func completeRegistration( ctx context.Context, - accountDB *accounts.Database, - deviceDB *devices.Database, + accountDB accounts.Database, + deviceDB devices.Database, username, password, appserviceID string, inhibitLogin common.WeakBoolean, displayName, deviceID *string, @@ -992,7 +992,7 @@ type availableResponse struct { func RegisterAvailable( req *http.Request, cfg *config.Dendrite, - accountDB *accounts.Database, + accountDB accounts.Database, ) util.JSONResponse { username := req.URL.Query().Get("username") diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index 487081c5..aa5f13c4 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -40,7 +40,7 @@ func newTag() gomatrix.TagContent { // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags func GetTags( req *http.Request, - accountDB *accounts.Database, + accountDB accounts.Database, device *authtypes.Device, userID string, roomID string, @@ -77,7 +77,7 @@ func GetTags( // the tag to the "map" and saving the new "map" to the DB func PutTag( req *http.Request, - accountDB *accounts.Database, + accountDB accounts.Database, device *authtypes.Device, userID string, roomID string, @@ -134,7 +134,7 @@ func PutTag( // the "map" and then saving the new "map" in the DB func DeleteTag( req *http.Request, - accountDB *accounts.Database, + accountDB accounts.Database, device *authtypes.Device, userID string, roomID string, @@ -203,7 +203,7 @@ func obtainSavedTags( req *http.Request, userID string, roomID string, - accountDB *accounts.Database, + accountDB accounts.Database, ) (string, *gomatrixserverlib.ClientEvent, error) { localpart, _, err := gomatrixserverlib.SplitID('@', userID) if err != nil { @@ -222,7 +222,7 @@ func saveTagData( req *http.Request, localpart string, roomID string, - accountDB *accounts.Database, + accountDB accounts.Database, Tag gomatrix.TagContent, ) error { newTagData, err := json.Marshal(Tag) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f519523a..f0841b79 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -52,8 +52,8 @@ func Setup( queryAPI roomserverAPI.RoomserverQueryAPI, aliasAPI roomserverAPI.RoomserverAliasAPI, asAPI appserviceAPI.AppServiceQueryAPI, - accountDB *accounts.Database, - deviceDB *devices.Database, + accountDB accounts.Database, + deviceDB devices.Database, federation *gomatrixserverlib.FederationClient, keyRing gomatrixserverlib.KeyRing, userUpdateProducer *producers.UserUpdateProducer, diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go index 561a2d89..db3ab28b 100644 --- a/clientapi/routing/sendtyping.go +++ b/clientapi/routing/sendtyping.go @@ -34,7 +34,7 @@ type typingContentJSON struct { // sends the typing events to client API typingProducer func SendTyping( req *http.Request, device *authtypes.Device, roomID string, - userID string, accountDB *accounts.Database, + userID string, accountDB accounts.Database, typingProducer *producers.TypingServerProducer, ) util.JSONResponse { if device.UserID != userID { diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 88b02fe4..69383cdf 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -39,7 +39,7 @@ type threePIDsResponse struct { // RequestEmailToken implements: // POST /account/3pid/email/requestToken // POST /register/email/requestToken -func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite) util.JSONResponse { +func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.Dendrite) util.JSONResponse { var body threepid.EmailAssociationRequest if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr @@ -82,7 +82,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg *con // CheckAndSave3PIDAssociation implements POST /account/3pid func CheckAndSave3PIDAssociation( - req *http.Request, accountDB *accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *authtypes.Device, cfg *config.Dendrite, ) util.JSONResponse { var body threepid.EmailAssociationCheckRequest @@ -142,7 +142,7 @@ func CheckAndSave3PIDAssociation( // GetAssociated3PIDs implements GET /account/3pid func GetAssociated3PIDs( - req *http.Request, accountDB *accounts.Database, device *authtypes.Device, + req *http.Request, accountDB accounts.Database, device *authtypes.Device, ) util.JSONResponse { localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { @@ -161,7 +161,7 @@ func GetAssociated3PIDs( } // Forget3PID implements POST /account/3pid/delete -func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONResponse { +func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse { var body authtypes.ThreePID if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil { return *reqErr diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 2cf88d6e..aa54aa9f 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -87,7 +87,7 @@ var ( func CheckAndProcessInvite( ctx context.Context, device *authtypes.Device, body *MembershipRequest, cfg *config.Dendrite, - queryAPI api.RoomserverQueryAPI, db *accounts.Database, + queryAPI api.RoomserverQueryAPI, db accounts.Database, producer *producers.RoomserverProducer, membership string, roomID string, evTime time.Time, ) (inviteStoredOnIDServer bool, err error) { @@ -137,7 +137,7 @@ func CheckAndProcessInvite( // Returns an error if a check or a request failed. func queryIDServer( ctx context.Context, - db *accounts.Database, cfg *config.Dendrite, device *authtypes.Device, + db accounts.Database, cfg *config.Dendrite, device *authtypes.Device, body *MembershipRequest, roomID string, ) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) { if err = isTrusted(body.IDServer, cfg); err != nil { @@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe // Returns an error if the request failed to send or if the response couldn't be parsed. func queryIDServerStoreInvite( ctx context.Context, - db *accounts.Database, cfg *config.Dendrite, device *authtypes.Device, + db accounts.Database, cfg *config.Dendrite, device *authtypes.Device, body *MembershipRequest, roomID string, ) (*idServerStoreInviteResponse, error) { // Retrieve the sender's profile to get their display name |