diff options
author | Kegsay <kegan@matrix.org> | 2020-06-17 12:05:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-17 12:05:56 +0100 |
commit | e09d24e7323e73791e7bb31fa7fac1d3acf0c299 (patch) | |
tree | 85600bfe84e938ffeb48d7a152f968f86b78fcdf /userapi | |
parent | 5d5aa0a31d60941c7ece95b4b516044cb8a10cce (diff) |
Move account/device DBs to userapi (#1141)
Diffstat (limited to 'userapi')
29 files changed, 3718 insertions, 6 deletions
diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 3a413166..ae021f57 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -21,12 +21,12 @@ import ( "fmt" "github.com/matrix-org/dendrite/appservice/types" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go new file mode 100644 index 00000000..13e3e289 --- /dev/null +++ b/userapi/storage/accounts/interface.go @@ -0,0 +1,62 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package accounts + +import ( + "context" + "errors" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database interface { + internal.PartitionStorer + GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) + GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error + SetDisplayName(ctx context.Context, localpart string, displayName string) error + // CreateAccount makes a new account with the given login name and password, and creates an empty profile + // for this account. If no password is supplied, the account will be a passwordless account. If the + // account already exists, it will return nil, ErrUserExists. + CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error) + CreateGuestAccount(ctx context.Context) (*api.Account, error) + UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error + GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) + GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) + GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) + SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error + GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) + // GetAccountDataByType returns account data matching a given + // localpart, room ID and type. + // If no account data could be found, returns nil + // Returns an error if there was an issue with the retrieval + GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error) + GetNewNumericLocalpart(ctx context.Context) (int64, error) + SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) + RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) + PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) + CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) + GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) +} + +// Err3PIDInUse is the error returned when trying to save an association involving +// a third-party identifier which is already associated to a local user. +var Err3PIDInUse = errors.New("This third-party identifier is already in use") diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go new file mode 100644 index 00000000..2f16c5c0 --- /dev/null +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -0,0 +1,142 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/gomatrixserverlib" +) + +const accountDataSchema = ` +-- Stores data about accounts data. +CREATE TABLE IF NOT EXISTS account_data ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL, + -- The room ID for this data (empty string if not specific to a room) + room_id TEXT, + -- The account data type + type TEXT NOT NULL, + -- The account data content + content TEXT NOT NULL, + + PRIMARY KEY(localpart, room_id, type) +); +` + +const insertAccountDataSQL = ` + INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content +` + +const selectAccountDataSQL = "" + + "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + +const selectAccountDataByTypeSQL = "" + + "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + +type accountDataStatements struct { + insertAccountDataStmt *sql.Stmt + selectAccountDataStmt *sql.Stmt + selectAccountDataByTypeStmt *sql.Stmt +} + +func (s *accountDataStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(accountDataSchema) + if err != nil { + return + } + if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { + return + } + if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { + return + } + if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { + return + } + return +} + +func (s *accountDataStatements) insertAccountData( + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, +) (err error) { + stmt := txn.Stmt(s.insertAccountDataStmt) + _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) + return +} + +func (s *accountDataStatements) selectAccountData( + ctx context.Context, localpart string, +) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") + + global = []gomatrixserverlib.ClientEvent{} + rooms = make(map[string][]gomatrixserverlib.ClientEvent) + + for rows.Next() { + var roomID string + var dataType string + var content []byte + + if err = rows.Scan(&roomID, &dataType, &content); err != nil { + return + } + + ac := gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: content, + } + + if len(roomID) > 0 { + rooms[roomID] = append(rooms[roomID], ac) + } else { + global = append(global, ac) + } + } + return global, rooms, rows.Err() +} + +func (s *accountDataStatements) selectAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data *gomatrixserverlib.ClientEvent, err error) { + stmt := s.selectAccountDataByTypeStmt + var content []byte + + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + + return + } + + data = &gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: content, + } + + return +} diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go new file mode 100644 index 00000000..931ffb73 --- /dev/null +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -0,0 +1,157 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + + log "github.com/sirupsen/logrus" +) + +const accountsSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS account_accounts ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- When this account was first created, as a unix timestamp (ms resolution). + created_ts BIGINT NOT NULL, + -- The password hash for this account. Can be NULL if this is a passwordless account. + password_hash TEXT, + -- Identifies which application service this account belongs to, if any. + appservice_id TEXT + -- TODO: + -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? +); +-- Create sequence for autogenerated numeric usernames +CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; +` + +const insertAccountSQL = "" + + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + +const selectAccountByLocalpartSQL = "" + + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + +const selectPasswordHashSQL = "" + + "SELECT password_hash FROM account_accounts WHERE localpart = $1" + +const selectNewNumericLocalpartSQL = "" + + "SELECT nextval('numeric_username_seq')" + +// TODO: Update password + +type accountsStatements struct { + insertAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + _, err = db.Exec(accountsSchema) + if err != nil { + return + } + if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { + return + } + if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { + return + } + if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { + return + } + if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, +// this account will be passwordless. Returns an error if this account already exists. Returns the account +// on success. +func (s *accountsStatements) insertAccount( + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, +) (*api.Account, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + stmt := txn.Stmt(s.insertAccountStmt) + + var err error + if appserviceID == "" { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) + } else { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + } + if err != nil { + return nil, err + } + + return &api.Account{ + Localpart: localpart, + UserID: userutil.MakeUserID(localpart, s.serverName), + ServerName: s.serverName, + AppServiceID: appserviceID, + }, nil +} + +func (s *accountsStatements) selectPasswordHash( + ctx context.Context, localpart string, +) (hash string, err error) { + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + return +} + +func (s *accountsStatements) selectAccountByLocalpart( + ctx context.Context, localpart string, +) (*api.Account, error) { + var appserviceIDPtr sql.NullString + var acc api.Account + + stmt := s.selectAccountByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve user from the db") + } + return nil, err + } + if appserviceIDPtr.Valid { + acc.AppServiceID = appserviceIDPtr.String + } + + acc.UserID = userutil.MakeUserID(localpart, s.serverName) + acc.ServerName = s.serverName + + return &acc, nil +} + +func (s *accountsStatements) selectNewNumericLocalpart( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + stmt := s.selectNewNumericLocalpartStmt + if txn != nil { + stmt = txn.Stmt(stmt) + } + err = stmt.QueryRowContext(ctx).Scan(&id) + return +} diff --git a/userapi/storage/accounts/postgres/filter_table.go b/userapi/storage/accounts/postgres/filter_table.go new file mode 100644 index 00000000..c54e4bc4 --- /dev/null +++ b/userapi/storage/accounts/postgres/filter_table.go @@ -0,0 +1,127 @@ +// Copyright 2017 Jan Christian Grünhage +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" +) + +const filterSchema = ` +-- Stores data about filters +CREATE TABLE IF NOT EXISTS account_filter ( + -- The filter + filter TEXT NOT NULL, + -- The ID + id SERIAL UNIQUE, + -- The localpart of the Matrix user ID associated to this filter + localpart TEXT NOT NULL, + + PRIMARY KEY(id, localpart) +); + +CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart); +` + +const selectFilterSQL = "" + + "SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2" + +const selectFilterIDByContentSQL = "" + + "SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2" + +const insertFilterSQL = "" + + "INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id" + +type filterStatements struct { + selectFilterStmt *sql.Stmt + selectFilterIDByContentStmt *sql.Stmt + insertFilterStmt *sql.Stmt +} + +func (s *filterStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(filterSchema) + if err != nil { + return + } + if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { + return + } + if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { + return + } + if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { + return + } + return +} + +func (s *filterStatements) selectFilter( + ctx context.Context, localpart string, filterID string, +) (*gomatrixserverlib.Filter, error) { + // Retrieve filter from database (stored as canonical JSON) + var filterData []byte + err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + if err != nil { + return nil, err + } + + // Unmarshal JSON into Filter struct + var filter gomatrixserverlib.Filter + if err = json.Unmarshal(filterData, &filter); err != nil { + return nil, err + } + return &filter, nil +} + +func (s *filterStatements) insertFilter( + ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, +) (filterID string, err error) { + var existingFilterID string + + // Serialise json + filterJSON, err := json.Marshal(filter) + if err != nil { + return "", err + } + // Remove whitespaces and sort JSON data + // needed to prevent from inserting the same filter multiple times + filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON) + if err != nil { + return "", err + } + + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil && err != sql.ErrNoRows { + return "", err + } + // If it does, return the existing ID + if existingFilterID != "" { + return existingFilterID, err + } + + // Otherwise insert the filter and return the new ID + err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart). + Scan(&filterID) + return +} diff --git a/userapi/storage/accounts/postgres/membership_table.go b/userapi/storage/accounts/postgres/membership_table.go new file mode 100644 index 00000000..623530ac --- /dev/null +++ b/userapi/storage/accounts/postgres/membership_table.go @@ -0,0 +1,159 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal" +) + +const membershipSchema = ` +-- Stores data about users memberships to rooms. +CREATE TABLE IF NOT EXISTS account_memberships ( + -- The Matrix user ID localpart for the member + localpart TEXT NOT NULL, + -- The room this user is a member of + room_id TEXT NOT NULL, + -- The ID of the join membership event + event_id TEXT NOT NULL, + + -- A user can only be member of a room once + PRIMARY KEY (localpart, room_id) +); + +-- Use index to process deletion by ID more efficiently +CREATE UNIQUE INDEX IF NOT EXISTS account_membership_event_id ON account_memberships(event_id); +` + +const insertMembershipSQL = ` + INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3) + ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id +` + +const selectMembershipsByLocalpartSQL = "" + + "SELECT room_id, event_id FROM account_memberships WHERE localpart = $1" + +const selectMembershipInRoomByLocalpartSQL = "" + + "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" + +const selectRoomIDsByLocalPartSQL = "" + + "SELECT room_id FROM account_memberships WHERE localpart = $1" + +const deleteMembershipsByEventIDsSQL = "" + + "DELETE FROM account_memberships WHERE event_id = ANY($1)" + +type membershipStatements struct { + deleteMembershipsByEventIDsStmt *sql.Stmt + insertMembershipStmt *sql.Stmt + selectMembershipInRoomByLocalpartStmt *sql.Stmt + selectMembershipsByLocalpartStmt *sql.Stmt + selectRoomIDsByLocalPartStmt *sql.Stmt +} + +func (s *membershipStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(membershipSchema) + if err != nil { + return + } + if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil { + return + } + if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil { + return + } + if s.selectMembershipInRoomByLocalpartStmt, err = db.Prepare(selectMembershipInRoomByLocalpartSQL); err != nil { + return + } + if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { + return + } + if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil { + return + } + return +} + +func (s *membershipStatements) insertMembership( + ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, +) (err error) { + stmt := txn.Stmt(s.insertMembershipStmt) + _, err = stmt.ExecContext(ctx, localpart, roomID, eventID) + return +} + +func (s *membershipStatements) deleteMembershipsByEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (err error) { + stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt) + _, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs)) + return +} + +func (s *membershipStatements) selectMembershipInRoomByLocalpart( + ctx context.Context, localpart, roomID string, +) (authtypes.Membership, error) { + membership := authtypes.Membership{Localpart: localpart, RoomID: roomID} + stmt := s.selectMembershipInRoomByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart, roomID).Scan(&membership.EventID) + + return membership, err +} + +func (s *membershipStatements) selectMembershipsByLocalpart( + ctx context.Context, localpart string, +) (memberships []authtypes.Membership, err error) { + stmt := s.selectMembershipsByLocalpartStmt + rows, err := stmt.QueryContext(ctx, localpart) + if err != nil { + return + } + + memberships = []authtypes.Membership{} + + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed") + for rows.Next() { + var m authtypes.Membership + m.Localpart = localpart + if err = rows.Scan(&m.RoomID, &m.EventID); err != nil { + return + } + memberships = append(memberships, m) + } + return memberships, rows.Err() +} + +func (s *membershipStatements) selectRoomIDsByLocalPart( + ctx context.Context, localPart string, +) ([]string, error) { + stmt := s.selectRoomIDsByLocalPartStmt + rows, err := stmt.QueryContext(ctx, localPart) + if err != nil { + return nil, err + } + roomIDs := []string{} + defer rows.Close() // nolint: errcheck + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, rows.Err() +} diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go new file mode 100644 index 00000000..d2cbeb8e --- /dev/null +++ b/userapi/storage/accounts/postgres/profile_table.go @@ -0,0 +1,107 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" +) + +const profilesSchema = ` +-- Stores data about accounts profiles. +CREATE TABLE IF NOT EXISTS account_profiles ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- The display name for this account + display_name TEXT, + -- The URL of the avatar for this account + avatar_url TEXT +); +` + +const insertProfileSQL = "" + + "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + +const selectProfileByLocalpartSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + +const setAvatarURLSQL = "" + + "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + +const setDisplayNameSQL = "" + + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + +type profilesStatements struct { + insertProfileStmt *sql.Stmt + selectProfileByLocalpartStmt *sql.Stmt + setAvatarURLStmt *sql.Stmt + setDisplayNameStmt *sql.Stmt +} + +func (s *profilesStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(profilesSchema) + if err != nil { + return + } + if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { + return + } + if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { + return + } + if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { + return + } + if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { + return + } + return +} + +func (s *profilesStatements) insertProfile( + ctx context.Context, txn *sql.Tx, localpart string, +) (err error) { + _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + return +} + +func (s *profilesStatements) selectProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + var profile authtypes.Profile + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( + &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + ) + if err != nil { + return nil, err + } + return &profile, nil +} + +func (s *profilesStatements) setAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) (err error) { + _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) + return +} + +func (s *profilesStatements) setDisplayName( + ctx context.Context, localpart string, displayName string, +) (err error) { + _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) + return +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go new file mode 100644 index 00000000..2b88cb70 --- /dev/null +++ b/userapi/storage/accounts/postgres/storage.go @@ -0,0 +1,433 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "strconv" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + + // Import the postgres database driver. + _ "github.com/lib/pq" +) + +// Database represents an account database +type Database struct { + db *sql.DB + sqlutil.PartitionOffsetStatements + accounts accountsStatements + profiles profilesStatements + memberships membershipStatements + accountDatas accountDataStatements + threepids threepidStatements + filter filterStatements + serverName gomatrixserverlib.ServerName +} + +// NewDatabase creates a new accounts and profiles database +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (*Database, error) { + var db *sql.DB + var err error + if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { + return nil, err + } + partitions := sqlutil.PartitionOffsetStatements{} + if err = partitions.Prepare(db, "account"); err != nil { + return nil, err + } + a := accountsStatements{} + if err = a.prepare(db, serverName); err != nil { + return nil, err + } + p := profilesStatements{} + if err = p.prepare(db); err != nil { + return nil, err + } + m := membershipStatements{} + if err = m.prepare(db); err != nil { + return nil, err + } + ac := accountDataStatements{} + if err = ac.prepare(db); err != nil { + return nil, err + } + t := threepidStatements{} + if err = t.prepare(db); err != nil { + return nil, err + } + f := filterStatements{} + if err = f.prepare(db); err != nil { + return nil, err + } + return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil +} + +// GetAccountByPassword returns the account associated with the given localpart and password. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByPassword( + ctx context.Context, localpart, plaintextPassword string, +) (*api.Account, error) { + hash, err := d.accounts.selectPasswordHash(ctx, localpart) + if err != nil { + return nil, err + } + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { + return nil, err + } + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} + +// GetProfileByLocalpart returns the profile associated with the given localpart. +// Returns sql.ErrNoRows if no profile exists which matches the given localpart. +func (d *Database) GetProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + return d.profiles.selectProfileByLocalpart(ctx, localpart) +} + +// SetAvatarURL updates the avatar URL of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) error { + return d.profiles.setAvatarURL(ctx, localpart, avatarURL) +} + +// SetDisplayName updates the display name of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetDisplayName( + ctx context.Context, localpart string, displayName string, +) error { + return d.profiles.setDisplayName(ctx, localpart, displayName) +} + +// CreateGuestAccount makes a new guest account and creates an empty profile +// for this account. +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart := strconv.FormatInt(numLocalpart, 10) + acc, err = d.createAccount(ctx, txn, localpart, "", "") + return err + }) + return acc, err +} + +// CreateAccount makes a new account with the given login name and password, and creates an empty profile +// for this account. If no password is supplied, the account will be a passwordless account. If the +// account already exists, it will return nil, sqlutil.ErrUserExists. +func (d *Database) CreateAccount( + ctx context.Context, localpart, plaintextPassword, appserviceID string, +) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + return err + }) + return +} + +func (d *Database) createAccount( + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, +) (*api.Account, error) { + var err error + + // Generate a password hash if this is not a password-less user + hash := "" + if plaintextPassword != "" { + hash, err = hashPassword(plaintextPassword) + if err != nil { + return nil, err + } + } + if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { + if sqlutil.IsUniqueConstraintViolationErr(err) { + return nil, sqlutil.ErrUserExists + } + return nil, err + } + + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + "global": { + "content": [], + "override": [], + "room": [], + "sender": [], + "underride": [] + } + }`); err != nil { + return nil, err + } + return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) +} + +// SaveMembership saves the user matching a given localpart as a member of a given +// room. It also stores the ID of the membership event. +// If a membership already exists between the user and the room, or if the +// insert fails, returns the SQL error +func (d *Database) saveMembership( + ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, +) error { + return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID) +} + +// removeMembershipsByEventIDs removes the memberships corresponding to the +// `join` membership events IDs in the eventIDs slice. +// If the removal fails, or if there is no membership to remove, returns an error +func (d *Database) removeMembershipsByEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) error { + return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs) +} + +// UpdateMemberships adds the "join" membership events included in a given state +// events array, and removes those which ID is included in a given array of events +// IDs. All of the process is run in a transaction, which commits only once/if every +// insertion and deletion has been successfully processed. +// Returns a SQL error if there was an issue with any part of the process +func (d *Database) UpdateMemberships( + ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { + return err + } + + for _, event := range eventsToAdd { + if err := d.newMembership(ctx, txn, event); err != nil { + return err + } + } + + return nil + }) +} + +// GetMembershipInRoomByLocalpart returns the membership for an user +// matching the given localpart if he is a member of the room matching roomID, +// if not sql.ErrNoRows is returned. +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembershipInRoomByLocalpart( + ctx context.Context, localpart, roomID string, +) (authtypes.Membership, error) { + return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID) +} + +// GetRoomIDsByLocalPart returns an array containing the room ids of all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetRoomIDsByLocalPart( + ctx context.Context, localpart string, +) ([]string, error) { + return d.memberships.selectRoomIDsByLocalPart(ctx, localpart) +} + +// GetMembershipsByLocalpart returns an array containing the memberships for all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembershipsByLocalpart( + ctx context.Context, localpart string, +) (memberships []authtypes.Membership, err error) { + return d.memberships.selectMembershipsByLocalpart(ctx, localpart) +} + +// newMembership saves a new membership in the database. +// If the event isn't a valid m.room.member event with type `join`, does nothing. +// If an error occurred, returns the SQL error +func (d *Database) newMembership( + ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event, +) error { + if ev.Type() == "m.room.member" && ev.StateKey() != nil { + localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + if err != nil { + return err + } + + // We only want state events from local users + if string(serverName) != string(d.serverName) { + return nil + } + + eventID := ev.EventID() + roomID := ev.RoomID() + membership, err := ev.Membership() + if err != nil { + return err + } + + // Only "join" membership events can be considered as new memberships + if membership == gomatrixserverlib.Join { + if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil { + return err + } + } + } + return nil +} + +// SaveAccountData saves new account data for a given user and a given room. +// If the account data is not specific to a room, the room ID should be an empty string +// If an account data already exists for a given set (user, room, data type), it will +// update the corresponding row with the new content +// Returns a SQL error if there was an issue with the insertion/update +func (d *Database) SaveAccountData( + ctx context.Context, localpart, roomID, dataType, content string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) + }) +} + +// GetAccountData returns account data related to a given localpart +// If no account data could be found, returns an empty arrays +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountData(ctx context.Context, localpart string) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + return d.accountDatas.selectAccountData(ctx, localpart) +} + +// GetAccountDataByType returns account data matching a given +// localpart, room ID and type. +// If no account data could be found, returns nil +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data *gomatrixserverlib.ClientEvent, err error) { + return d.accountDatas.selectAccountDataByType( + ctx, localpart, roomID, dataType, + ) +} + +// GetNewNumericLocalpart generates and returns a new unused numeric localpart +func (d *Database) GetNewNumericLocalpart( + ctx context.Context, +) (int64, error) { + return d.accounts.selectNewNumericLocalpart(ctx, nil) +} + +func hashPassword(plaintext string) (hash string, err error) { + hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost) + return string(hashBytes), err +} + +// Err3PIDInUse is the error returned when trying to save an association involving +// a third-party identifier which is already associated to a local user. +var Err3PIDInUse = errors.New("This third-party identifier is already in use") + +// SaveThreePIDAssociation saves the association between a third party identifier +// and a local Matrix user (identified by the user's ID's local part). +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) SaveThreePIDAssociation( + ctx context.Context, threepid, localpart, medium string, +) (err error) { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + user, err := d.threepids.selectLocalpartForThreePID( + ctx, txn, threepid, medium, + ) + if err != nil { + return err + } + + if len(user) > 0 { + return Err3PIDInUse + } + + return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) + }) +} + +// RemoveThreePIDAssociation removes the association involving a given third-party +// identifier. +// If no association exists involving this third-party identifier, returns nothing. +// If there was a problem talking to the database, returns an error. +func (d *Database) RemoveThreePIDAssociation( + ctx context.Context, threepid string, medium string, +) (err error) { + return d.threepids.deleteThreePID(ctx, threepid, medium) +} + +// GetLocalpartForThreePID looks up the localpart associated with a given third-party +// identifier. +// If no association involves the given third-party idenfitier, returns an empty +// string. +// Returns an error if there was a problem talking to the database. +func (d *Database) GetLocalpartForThreePID( + ctx context.Context, threepid string, medium string, +) (localpart string, err error) { + return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) +} + +// GetThreePIDsForLocalpart looks up the third-party identifiers associated with +// a given local user. +// If no association is known for this user, returns an empty slice. +// Returns an error if there was an issue talking to the database. +func (d *Database) GetThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) +} + +// GetFilter looks up the filter associated with a given local user and filter ID. +// Returns a filter structure. Otherwise returns an error if no such filter exists +// or if there was an error talking to the database. +func (d *Database) GetFilter( + ctx context.Context, localpart string, filterID string, +) (*gomatrixserverlib.Filter, error) { + return d.filter.selectFilter(ctx, localpart, filterID) +} + +// PutFilter puts the passed filter into the database. +// Returns the filterID as a string. Otherwise returns an error if something +// goes wrong. +func (d *Database) PutFilter( + ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, +) (string, error) { + return d.filter.insertFilter(ctx, filter, localpart) +} + +// CheckAccountAvailability checks if the username/localpart is already present +// in the database. +// If the DB returns sql.ErrNoRows the Localpart isn't taken. +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { + _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) + if err == sql.ErrNoRows { + return true, nil + } + return false, err +} + +// GetAccountByLocalpart returns the account associated with the given localpart. +// This function assumes the request is authenticated or the account data is used only internally. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +) (*api.Account, error) { + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/accounts/postgres/threepid_table.go new file mode 100644 index 00000000..7de96350 --- /dev/null +++ b/userapi/storage/accounts/postgres/threepid_table.go @@ -0,0 +1,129 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" +) + +const threepidSchema = ` +-- Stores data about third party identifiers +CREATE TABLE IF NOT EXISTS account_threepid ( + -- The third party identifier + threepid TEXT NOT NULL, + -- The 3PID medium + medium TEXT NOT NULL DEFAULT 'email', + -- The localpart of the Matrix user ID associated to this 3PID + localpart TEXT NOT NULL, + + PRIMARY KEY(threepid, medium) +); + +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +` + +const selectLocalpartForThreePIDSQL = "" + + "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + +const selectThreePIDsForLocalpartSQL = "" + + "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + +const insertThreePIDSQL = "" + + "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + +const deleteThreePIDSQL = "" + + "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + +type threepidStatements struct { + selectLocalpartForThreePIDStmt *sql.Stmt + selectThreePIDsForLocalpartStmt *sql.Stmt + insertThreePIDStmt *sql.Stmt + deleteThreePIDStmt *sql.Stmt +} + +func (s *threepidStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(threepidSchema) + if err != nil { + return + } + if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { + return + } + if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { + return + } + if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { + return + } + if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { + return + } + + return +} + +func (s *threepidStatements) selectLocalpartForThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string, +) (localpart string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *threepidStatements) selectThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + if err != nil { + return + } + + threepids = []authtypes.ThreePID{} + for rows.Next() { + var threepid string + var medium string + if err = rows.Scan(&threepid, &medium); err != nil { + return + } + threepids = append(threepids, authtypes.ThreePID{ + Address: threepid, + Medium: medium, + }) + } + + return +} + +func (s *threepidStatements) insertThreePID( + ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + return +} + +func (s *threepidStatements) deleteThreePID( + ctx context.Context, threepid string, medium string) (err error) { + _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) + return +} diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go new file mode 100644 index 00000000..b6bb6361 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -0,0 +1,140 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +const accountDataSchema = ` +-- Stores data about accounts data. +CREATE TABLE IF NOT EXISTS account_data ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL, + -- The room ID for this data (empty string if not specific to a room) + room_id TEXT, + -- The account data type + type TEXT NOT NULL, + -- The account data content + content TEXT NOT NULL, + + PRIMARY KEY(localpart, room_id, type) +); +` + +const insertAccountDataSQL = ` + INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 +` + +const selectAccountDataSQL = "" + + "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + +const selectAccountDataByTypeSQL = "" + + "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + +type accountDataStatements struct { + insertAccountDataStmt *sql.Stmt + selectAccountDataStmt *sql.Stmt + selectAccountDataByTypeStmt *sql.Stmt +} + +func (s *accountDataStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(accountDataSchema) + if err != nil { + return + } + if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { + return + } + if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { + return + } + if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { + return + } + return +} + +func (s *accountDataStatements) insertAccountData( + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, +) (err error) { + _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + return +} + +func (s *accountDataStatements) selectAccountData( + ctx context.Context, localpart string, +) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + if err != nil { + return + } + + global = []gomatrixserverlib.ClientEvent{} + rooms = make(map[string][]gomatrixserverlib.ClientEvent) + + for rows.Next() { + var roomID string + var dataType string + var content []byte + + if err = rows.Scan(&roomID, &dataType, &content); err != nil { + return + } + + ac := gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: content, + } + + if len(roomID) > 0 { + rooms[roomID] = append(rooms[roomID], ac) + } else { + global = append(global, ac) + } + } + + return +} + +func (s *accountDataStatements) selectAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data *gomatrixserverlib.ClientEvent, err error) { + stmt := s.selectAccountDataByTypeStmt + var content []byte + + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + + return + } + + data = &gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: content, + } + + return +} diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go new file mode 100644 index 00000000..768f536d --- /dev/null +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -0,0 +1,155 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + + log "github.com/sirupsen/logrus" +) + +const accountsSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS account_accounts ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- When this account was first created, as a unix timestamp (ms resolution). + created_ts BIGINT NOT NULL, + -- The password hash for this account. Can be NULL if this is a passwordless account. + password_hash TEXT, + -- Identifies which application service this account belongs to, if any. + appservice_id TEXT + -- TODO: + -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? +); +` + +const insertAccountSQL = "" + + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + +const selectAccountByLocalpartSQL = "" + + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + +const selectPasswordHashSQL = "" + + "SELECT password_hash FROM account_accounts WHERE localpart = $1" + +const selectNewNumericLocalpartSQL = "" + + "SELECT COUNT(localpart) FROM account_accounts" + +// TODO: Update password + +type accountsStatements struct { + insertAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + _, err = db.Exec(accountsSchema) + if err != nil { + return + } + if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { + return + } + if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { + return + } + if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { + return + } + if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, +// this account will be passwordless. Returns an error if this account already exists. Returns the account +// on success. +func (s *accountsStatements) insertAccount( + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, +) (*api.Account, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + stmt := s.insertAccountStmt + + var err error + if appserviceID == "" { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + } else { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + } + if err != nil { + return nil, err + } + + return &api.Account{ + Localpart: localpart, + UserID: userutil.MakeUserID(localpart, s.serverName), + ServerName: s.serverName, + AppServiceID: appserviceID, + }, nil +} + +func (s *accountsStatements) selectPasswordHash( + ctx context.Context, localpart string, +) (hash string, err error) { + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + return +} + +func (s *accountsStatements) selectAccountByLocalpart( + ctx context.Context, localpart string, +) (*api.Account, error) { + var appserviceIDPtr sql.NullString + var acc api.Account + + stmt := s.selectAccountByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve user from the db") + } + return nil, err + } + if appserviceIDPtr.Valid { + acc.AppServiceID = appserviceIDPtr.String + } + + acc.UserID = userutil.MakeUserID(localpart, s.serverName) + acc.ServerName = s.serverName + + return &acc, nil +} + +func (s *accountsStatements) selectNewNumericLocalpart( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + stmt := s.selectNewNumericLocalpartStmt + if txn != nil { + stmt = txn.Stmt(stmt) + } + err = stmt.QueryRowContext(ctx).Scan(&id) + return +} diff --git a/userapi/storage/accounts/sqlite3/constraint.go b/userapi/storage/accounts/sqlite3/constraint.go new file mode 100644 index 00000000..32f96c8e --- /dev/null +++ b/userapi/storage/accounts/sqlite3/constraint.go @@ -0,0 +1,27 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !wasm + +package sqlite3 + +import ( + "errors" + + "github.com/mattn/go-sqlite3" +) + +func isConstraintError(err error) bool { + return errors.Is(err, sqlite3.ErrConstraint) +} diff --git a/userapi/storage/accounts/sqlite3/constraint_wasm.go b/userapi/storage/accounts/sqlite3/constraint_wasm.go new file mode 100644 index 00000000..0dd5b1fe --- /dev/null +++ b/userapi/storage/accounts/sqlite3/constraint_wasm.go @@ -0,0 +1,21 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build wasm + +package sqlite3 + +func isConstraintError(err error) bool { + return false +} diff --git a/userapi/storage/accounts/sqlite3/filter_table.go b/userapi/storage/accounts/sqlite3/filter_table.go new file mode 100644 index 00000000..7f1a0c24 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/filter_table.go @@ -0,0 +1,135 @@ +// Copyright 2017 Jan Christian Grünhage +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/gomatrixserverlib" +) + +const filterSchema = ` +-- Stores data about filters +CREATE TABLE IF NOT EXISTS account_filter ( + -- The filter + filter TEXT NOT NULL, + -- The ID + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The localpart of the Matrix user ID associated to this filter + localpart TEXT NOT NULL, + + UNIQUE (id, localpart) +); + +CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart); +` + +const selectFilterSQL = "" + + "SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2" + +const selectFilterIDByContentSQL = "" + + "SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2" + +const insertFilterSQL = "" + + "INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)" + +type filterStatements struct { + selectFilterStmt *sql.Stmt + selectFilterIDByContentStmt *sql.Stmt + insertFilterStmt *sql.Stmt +} + +func (s *filterStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(filterSchema) + if err != nil { + return + } + if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { + return + } + if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil { + return + } + if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { + return + } + return +} + +func (s *filterStatements) selectFilter( + ctx context.Context, localpart string, filterID string, +) (*gomatrixserverlib.Filter, error) { + // Retrieve filter from database (stored as canonical JSON) + var filterData []byte + err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + if err != nil { + return nil, err + } + + // Unmarshal JSON into Filter struct + var filter gomatrixserverlib.Filter + if err = json.Unmarshal(filterData, &filter); err != nil { + return nil, err + } + return &filter, nil +} + +func (s *filterStatements) insertFilter( + ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, +) (filterID string, err error) { + var existingFilterID string + + // Serialise json + filterJSON, err := json.Marshal(filter) + if err != nil { + return "", err + } + // Remove whitespaces and sort JSON data + // needed to prevent from inserting the same filter multiple times + filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON) + if err != nil { + return "", err + } + + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil && err != sql.ErrNoRows { + return "", err + } + // If it does, return the existing ID + if existingFilterID != "" { + return existingFilterID, err + } + + // Otherwise insert the filter and return the new ID + res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + if err != nil { + return "", err + } + rowid, err := res.LastInsertId() + if err != nil { + return "", err + } + filterID = fmt.Sprintf("%d", rowid) + return +} diff --git a/userapi/storage/accounts/sqlite3/membership_table.go b/userapi/storage/accounts/sqlite3/membership_table.go new file mode 100644 index 00000000..67958f27 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/membership_table.go @@ -0,0 +1,159 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const membershipSchema = ` +-- Stores data about users memberships to rooms. +CREATE TABLE IF NOT EXISTS account_memberships ( + -- The Matrix user ID localpart for the member + localpart TEXT NOT NULL, + -- The room this user is a member of + room_id TEXT NOT NULL, + -- The ID of the join membership event + event_id TEXT NOT NULL, + + -- A user can only be member of a room once + PRIMARY KEY (localpart, room_id), + + UNIQUE (event_id) +); +` + +const insertMembershipSQL = ` + INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3) + ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id +` + +const selectMembershipsByLocalpartSQL = "" + + "SELECT room_id, event_id FROM account_memberships WHERE localpart = $1" + +const selectMembershipInRoomByLocalpartSQL = "" + + "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" + +const selectRoomIDsByLocalPartSQL = "" + + "SELECT room_id FROM account_memberships WHERE localpart = $1" + +const deleteMembershipsByEventIDsSQL = "" + + "DELETE FROM account_memberships WHERE event_id IN ($1)" + +type membershipStatements struct { + insertMembershipStmt *sql.Stmt + selectMembershipInRoomByLocalpartStmt *sql.Stmt + selectMembershipsByLocalpartStmt *sql.Stmt + selectRoomIDsByLocalPartStmt *sql.Stmt +} + +func (s *membershipStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(membershipSchema) + if err != nil { + return + } + if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil { + return + } + if s.selectMembershipInRoomByLocalpartStmt, err = db.Prepare(selectMembershipInRoomByLocalpartSQL); err != nil { + return + } + if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { + return + } + if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil { + return + } + return +} + +func (s *membershipStatements) insertMembership( + ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, +) (err error) { + stmt := txn.Stmt(s.insertMembershipStmt) + _, err = stmt.ExecContext(ctx, localpart, roomID, eventID) + return +} + +func (s *membershipStatements) deleteMembershipsByEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) (err error) { + sqlStr := strings.Replace(deleteMembershipsByEventIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1) + iEventIDs := make([]interface{}, len(eventIDs)) + for i, e := range eventIDs { + iEventIDs[i] = e + } + _, err = txn.ExecContext(ctx, sqlStr, iEventIDs...) + return +} + +func (s *membershipStatements) selectMembershipInRoomByLocalpart( + ctx context.Context, localpart, roomID string, +) (authtypes.Membership, error) { + membership := authtypes.Membership{Localpart: localpart, RoomID: roomID} + stmt := s.selectMembershipInRoomByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart, roomID).Scan(&membership.EventID) + + return membership, err +} + +func (s *membershipStatements) selectMembershipsByLocalpart( + ctx context.Context, localpart string, +) (memberships []authtypes.Membership, err error) { + stmt := s.selectMembershipsByLocalpartStmt + rows, err := stmt.QueryContext(ctx, localpart) + if err != nil { + return + } + + memberships = []authtypes.Membership{} + + defer internal.CloseAndLogIfError(ctx, rows, "selectMembershipsByLocalpart: rows.close() failed") + for rows.Next() { + var m authtypes.Membership + m.Localpart = localpart + if err := rows.Scan(&m.RoomID, &m.EventID); err != nil { + return nil, err + } + memberships = append(memberships, m) + } + + return +} +func (s *membershipStatements) selectRoomIDsByLocalPart( + ctx context.Context, localPart string, +) ([]string, error) { + stmt := s.selectRoomIDsByLocalPartStmt + rows, err := stmt.QueryContext(ctx, localPart) + if err != nil { + return nil, err + } + roomIDs := []string{} + defer rows.Close() // nolint: errcheck + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, rows.Err() +} diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go new file mode 100644 index 00000000..9b5192a0 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -0,0 +1,107 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" +) + +const profilesSchema = ` +-- Stores data about accounts profiles. +CREATE TABLE IF NOT EXISTS account_profiles ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- The display name for this account + display_name TEXT, + -- The URL of the avatar for this account + avatar_url TEXT +); +` + +const insertProfileSQL = "" + + "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + +const selectProfileByLocalpartSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1" + +const setAvatarURLSQL = "" + + "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2" + +const setDisplayNameSQL = "" + + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" + +type profilesStatements struct { + insertProfileStmt *sql.Stmt + selectProfileByLocalpartStmt *sql.Stmt + setAvatarURLStmt *sql.Stmt + setDisplayNameStmt *sql.Stmt +} + +func (s *profilesStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(profilesSchema) + if err != nil { + return + } + if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { + return + } + if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { + return + } + if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { + return + } + if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { + return + } + return +} + +func (s *profilesStatements) insertProfile( + ctx context.Context, txn *sql.Tx, localpart string, +) (err error) { + _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + return +} + +func (s *profilesStatements) selectProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + var profile authtypes.Profile + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( + &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + ) + if err != nil { + return nil, err + } + return &profile, nil +} + +func (s *profilesStatements) setAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) (err error) { + _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) + return +} + +func (s *profilesStatements) setDisplayName( + ctx context.Context, localpart string, displayName string, +) (err error) { + _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) + return +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go new file mode 100644 index 00000000..4dd755a7 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -0,0 +1,444 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "strconv" + "sync" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + // Import the sqlite3 database driver. +) + +// Database represents an account database +type Database struct { + db *sql.DB + sqlutil.PartitionOffsetStatements + accounts accountsStatements + profiles profilesStatements + memberships membershipStatements + accountDatas accountDataStatements + threepids threepidStatements + filter filterStatements + serverName gomatrixserverlib.ServerName + + createGuestAccountMu sync.Mutex +} + +// NewDatabase creates a new accounts and profiles database +func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { + var db *sql.DB + var err error + cs, err := sqlutil.ParseFileURI(dataSourceName) + if err != nil { + return nil, err + } + if db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { + return nil, err + } + partitions := sqlutil.PartitionOffsetStatements{} + if err = partitions.Prepare(db, "account"); err != nil { + return nil, err + } + a := accountsStatements{} + if err = a.prepare(db, serverName); err != nil { + return nil, err + } + p := profilesStatements{} + if err = p.prepare(db); err != nil { + return nil, err + } + m := membershipStatements{} + if err = m.prepare(db); err != nil { + return nil, err + } + ac := accountDataStatements{} + if err = ac.prepare(db); err != nil { + return nil, err + } + t := threepidStatements{} + if err = t.prepare(db); err != nil { + return nil, err + } + f := filterStatements{} + if err = f.prepare(db); err != nil { + return nil, err + } + return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil +} + +// GetAccountByPassword returns the account associated with the given localpart and password. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByPassword( + ctx context.Context, localpart, plaintextPassword string, +) (*api.Account, error) { + hash, err := d.accounts.selectPasswordHash(ctx, localpart) + if err != nil { + return nil, err + } + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { + return nil, err + } + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} + +// GetProfileByLocalpart returns the profile associated with the given localpart. +// Returns sql.ErrNoRows if no profile exists which matches the given localpart. +func (d *Database) GetProfileByLocalpart( + ctx context.Context, localpart string, +) (*authtypes.Profile, error) { + return d.profiles.selectProfileByLocalpart(ctx, localpart) +} + +// SetAvatarURL updates the avatar URL of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetAvatarURL( + ctx context.Context, localpart string, avatarURL string, +) error { + return d.profiles.setAvatarURL(ctx, localpart, avatarURL) +} + +// SetDisplayName updates the display name of the profile associated with the given +// localpart. Returns an error if something went wrong with the SQL query +func (d *Database) SetDisplayName( + ctx context.Context, localpart string, displayName string, +) error { + return d.profiles.setDisplayName(ctx, localpart, displayName) +} + +// CreateGuestAccount makes a new guest account and creates an empty profile +// for this account. +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + // We need to lock so we sequentially create numeric localparts. If we don't, two calls to + // this function will cause the same number to be selected and one will fail with 'database is locked' + // when the first txn upgrades to a write txn. + // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. + d.createGuestAccountMu.Lock() + defer d.createGuestAccountMu.Unlock() + + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart := strconv.FormatInt(numLocalpart, 10) + acc, err = d.createAccount(ctx, txn, localpart, "", "") + return err + }) + return acc, err +} + +// CreateAccount makes a new account with the given login name and password, and creates an empty profile +// for this account. If no password is supplied, the account will be a passwordless account. If the +// account already exists, it will return nil, ErrUserExists. +func (d *Database) CreateAccount( + ctx context.Context, localpart, plaintextPassword, appserviceID string, +) (acc *api.Account, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + return err + }) + return +} + +func (d *Database) createAccount( + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, +) (*api.Account, error) { + var err error + // Generate a password hash if this is not a password-less user + hash := "" + if plaintextPassword != "" { + hash, err = hashPassword(plaintextPassword) + if err != nil { + return nil, err + } + } + if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { + if isConstraintError(err) { + return nil, sqlutil.ErrUserExists + } + return nil, err + } + + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + "global": { + "content": [], + "override": [], + "room": [], + "sender": [], + "underride": [] + } + }`); err != nil { + return nil, err + } + return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) +} + +// SaveMembership saves the user matching a given localpart as a member of a given +// room. It also stores the ID of the membership event. +// If a membership already exists between the user and the room, or if the +// insert fails, returns the SQL error +func (d *Database) saveMembership( + ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string, +) error { + return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID) +} + +// removeMembershipsByEventIDs removes the memberships corresponding to the +// `join` membership events IDs in the eventIDs slice. +// If the removal fails, or if there is no membership to remove, returns an error +func (d *Database) removeMembershipsByEventIDs( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) error { + return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs) +} + +// UpdateMemberships adds the "join" membership events included in a given state +// events array, and removes those which ID is included in a given array of events +// IDs. All of the process is run in a transaction, which commits only once/if every +// insertion and deletion has been successfully processed. +// Returns a SQL error if there was an issue with any part of the process +func (d *Database) UpdateMemberships( + ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil { + return err + } + + for _, event := range eventsToAdd { + if err := d.newMembership(ctx, txn, event); err != nil { + return err + } + } + + return nil + }) +} + +// GetMembershipInRoomByLocalpart returns the membership for an user +// matching the given localpart if he is a member of the room matching roomID, +// if not sql.ErrNoRows is returned. +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembershipInRoomByLocalpart( + ctx context.Context, localpart, roomID string, +) (authtypes.Membership, error) { + return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID) +} + +// GetMembershipsByLocalpart returns an array containing the memberships for all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetMembershipsByLocalpart( + ctx context.Context, localpart string, +) (memberships []authtypes.Membership, err error) { + return d.memberships.selectMembershipsByLocalpart(ctx, localpart) +} + +// GetRoomIDsByLocalPart returns an array containing the room ids of all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetRoomIDsByLocalPart( + ctx context.Context, localpart string, +) ([]string, error) { + return d.memberships.selectRoomIDsByLocalPart(ctx, localpart) +} + +// newMembership saves a new membership in the database. +// If the event isn't a valid m.room.member event with type `join`, does nothing. +// If an error occurred, returns the SQL error +func (d *Database) newMembership( + ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event, +) error { + if ev.Type() == "m.room.member" && ev.StateKey() != nil { + localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) + if err != nil { + return err + } + + // We only want state events from local users + if string(serverName) != string(d.serverName) { + return nil + } + + eventID := ev.EventID() + roomID := ev.RoomID() + membership, err := ev.Membership() + if err != nil { + return err + } + + // Only "join" membership events can be considered as new memberships + if membership == gomatrixserverlib.Join { + if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil { + return err + } + } + } + return nil +} + +// SaveAccountData saves new account data for a given user and a given room. +// If the account data is not specific to a room, the room ID should be an empty string +// If an account data already exists for a given set (user, room, data type), it will +// update the corresponding row with the new content +// Returns a SQL error if there was an issue with the insertion/update +func (d *Database) SaveAccountData( + ctx context.Context, localpart, roomID, dataType, content string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) + }) +} + +// GetAccountData returns account data related to a given localpart +// If no account data could be found, returns an empty arrays +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountData(ctx context.Context, localpart string) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + return d.accountDatas.selectAccountData(ctx, localpart) +} + +// GetAccountDataByType returns account data matching a given +// localpart, room ID and type. +// If no account data could be found, returns nil +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountDataByType( + ctx context.Context, localpart, roomID, dataType string, +) (data *gomatrixserverlib.ClientEvent, err error) { + return d.accountDatas.selectAccountDataByType( + ctx, localpart, roomID, dataType, + ) +} + +// GetNewNumericLocalpart generates and returns a new unused numeric localpart +func (d *Database) GetNewNumericLocalpart( + ctx context.Context, +) (int64, error) { + return d.accounts.selectNewNumericLocalpart(ctx, nil) +} + +func hashPassword(plaintext string) (hash string, err error) { + hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost) + return string(hashBytes), err +} + +// Err3PIDInUse is the error returned when trying to save an association involving +// a third-party identifier which is already associated to a local user. +var Err3PIDInUse = errors.New("This third-party identifier is already in use") + +// SaveThreePIDAssociation saves the association between a third party identifier +// and a local Matrix user (identified by the user's ID's local part). +// If the third-party identifier is already part of an association, returns Err3PIDInUse. +// Returns an error if there was a problem talking to the database. +func (d *Database) SaveThreePIDAssociation( + ctx context.Context, threepid, localpart, medium string, +) (err error) { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + user, err := d.threepids.selectLocalpartForThreePID( + ctx, txn, threepid, medium, + ) + if err != nil { + return err + } + + if len(user) > 0 { + return Err3PIDInUse + } + + return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart) + }) +} + +// RemoveThreePIDAssociation removes the association involving a given third-party +// identifier. +// If no association exists involving this third-party identifier, returns nothing. +// If there was a problem talking to the database, returns an error. +func (d *Database) RemoveThreePIDAssociation( + ctx context.Context, threepid string, medium string, +) (err error) { + return d.threepids.deleteThreePID(ctx, threepid, medium) +} + +// GetLocalpartForThreePID looks up the localpart associated with a given third-party +// identifier. +// If no association involves the given third-party idenfitier, returns an empty +// string. +// Returns an error if there was a problem talking to the database. +func (d *Database) GetLocalpartForThreePID( + ctx context.Context, threepid string, medium string, +) (localpart string, err error) { + return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium) +} + +// GetThreePIDsForLocalpart looks up the third-party identifiers associated with +// a given local user. +// If no association is known for this user, returns an empty slice. +// Returns an error if there was an issue talking to the database. +func (d *Database) GetThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + return d.threepids.selectThreePIDsForLocalpart(ctx, localpart) +} + +// GetFilter looks up the filter associated with a given local user and filter ID. +// Returns a filter structure. Otherwise returns an error if no such filter exists +// or if there was an error talking to the database. +func (d *Database) GetFilter( + ctx context.Context, localpart string, filterID string, +) (*gomatrixserverlib.Filter, error) { + return d.filter.selectFilter(ctx, localpart, filterID) +} + +// PutFilter puts the passed filter into the database. +// Returns the filterID as a string. Otherwise returns an error if something +// goes wrong. +func (d *Database) PutFilter( + ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, +) (string, error) { + return d.filter.insertFilter(ctx, filter, localpart) +} + +// CheckAccountAvailability checks if the username/localpart is already present +// in the database. +// If the DB returns sql.ErrNoRows the Localpart isn't taken. +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { + _, err := d.accounts.selectAccountByLocalpart(ctx, localpart) + if err == sql.ErrNoRows { + return true, nil + } + return false, err +} + +// GetAccountByLocalpart returns the account associated with the given localpart. +// This function assumes the request is authenticated or the account data is used only internally. +// Returns sql.ErrNoRows if no account exists which matches the given localpart. +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +) (*api.Account, error) { + return d.accounts.selectAccountByLocalpart(ctx, localpart) +} diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go new file mode 100644 index 00000000..0200dee7 --- /dev/null +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -0,0 +1,130 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" +) + +const threepidSchema = ` +-- Stores data about third party identifiers +CREATE TABLE IF NOT EXISTS account_threepid ( + -- The third party identifier + threepid TEXT NOT NULL, + -- The 3PID medium + medium TEXT NOT NULL DEFAULT 'email', + -- The localpart of the Matrix user ID associated to this 3PID + localpart TEXT NOT NULL, + + PRIMARY KEY(threepid, medium) +); + +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart); +` + +const selectLocalpartForThreePIDSQL = "" + + "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2" + +const selectThreePIDsForLocalpartSQL = "" + + "SELECT threepid, medium FROM account_threepid WHERE localpart = $1" + +const insertThreePIDSQL = "" + + "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)" + +const deleteThreePIDSQL = "" + + "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2" + +type threepidStatements struct { + selectLocalpartForThreePIDStmt *sql.Stmt + selectThreePIDsForLocalpartStmt *sql.Stmt + insertThreePIDStmt *sql.Stmt + deleteThreePIDStmt *sql.Stmt +} + +func (s *threepidStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(threepidSchema) + if err != nil { + return + } + if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { + return + } + if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { + return + } + if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { + return + } + if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { + return + } + + return +} + +func (s *threepidStatements) selectLocalpartForThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string, +) (localpart string, err error) { + stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + if err == sql.ErrNoRows { + return "", nil + } + return +} + +func (s *threepidStatements) selectThreePIDsForLocalpart( + ctx context.Context, localpart string, +) (threepids []authtypes.ThreePID, err error) { + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + if err != nil { + return + } + defer internal.CloseAndLogIfError(ctx, rows, "selectThreePIDsForLocalpart: rows.close() failed") + + threepids = []authtypes.ThreePID{} + for rows.Next() { + var threepid string + var medium string + if err = rows.Scan(&threepid, &medium); err != nil { + return + } + threepids = append(threepids, authtypes.ThreePID{ + Address: threepid, + Medium: medium, + }) + } + return threepids, rows.Err() +} + +func (s *threepidStatements) insertThreePID( + ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, +) (err error) { + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + return +} + +func (s *threepidStatements) deleteThreePID( + ctx context.Context, threepid string, medium string) (err error) { + _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) + return +} diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/accounts/storage.go new file mode 100644 index 00000000..87f626bf --- /dev/null +++ b/userapi/storage/accounts/storage.go @@ -0,0 +1,43 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !wasm + +package accounts + +import ( + "net/url" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres" + "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" + "github.com/matrix-org/gomatrixserverlib" +) + +// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// and sets postgres connection parameters +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (Database, error) { + uri, err := url.Parse(dataSourceName) + if err != nil { + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) + } + switch uri.Scheme { + case "postgres": + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) + case "file": + return sqlite3.NewDatabase(dataSourceName, serverName) + default: + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) + } +} diff --git a/userapi/storage/accounts/storage_wasm.go b/userapi/storage/accounts/storage_wasm.go new file mode 100644 index 00000000..69256705 --- /dev/null +++ b/userapi/storage/accounts/storage_wasm.go @@ -0,0 +1,43 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package accounts + +import ( + "fmt" + "net/url" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3" + "github.com/matrix-org/gomatrixserverlib" +) + +func NewDatabase( + dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam + serverName gomatrixserverlib.ServerName, +) (Database, error) { + uri, err := url.Parse(dataSourceName) + if err != nil { + return nil, fmt.Errorf("Cannot use postgres implementation") + } + switch uri.Scheme { + case "postgres": + return nil, fmt.Errorf("Cannot use postgres implementation") + case "file": + return sqlite3.NewDatabase(dataSourceName, serverName) + default: + return nil, fmt.Errorf("Cannot use postgres implementation") + } +} diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go new file mode 100644 index 00000000..4bdb5785 --- /dev/null +++ b/userapi/storage/devices/interface.go @@ -0,0 +1,38 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package devices + +import ( + "context" + + "github.com/matrix-org/dendrite/userapi/api" +) + +type Database interface { + GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) + GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + // CreateDevice makes a new device associated with the given user ID localpart. + // If there is already a device with the same device ID for this user, that access token will be revoked + // and replaced with the given accessToken. If the given accessToken is already in use for another device, + // an error will be returned. + // If no device ID is given one is generated. + // Returns the device on success. + CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error + RemoveDevice(ctx context.Context, deviceID, localpart string) error + RemoveDevices(ctx context.Context, localpart string, devices []string) error + RemoveAllDevices(ctx context.Context, localpart string) error +} diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go new file mode 100644 index 00000000..1d036d1b --- /dev/null +++ b/userapi/storage/devices/postgres/devices_table.go @@ -0,0 +1,249 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +const devicesSchema = ` +-- This sequence is used for automatic allocation of session_id. +CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; + +-- Stores data about devices. +CREATE TABLE IF NOT EXISTS device_devices ( + -- The access token granted to this device. This has to be the primary key + -- so we can distinguish which device is making a given request. + access_token TEXT NOT NULL PRIMARY KEY, + -- The auto-allocated unique ID of the session identified by the access token. + -- This can be used as a secure substitution of the access token in situations + -- where data is associated with access tokens (e.g. transaction storage), + -- so we don't have to store users' access tokens everywhere. + session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'), + -- The device identifier. This only needs to uniquely identify a device for a given user, not globally. + -- access_tokens will be clobbered based on the device ID for a user. + device_id TEXT NOT NULL, + -- The Matrix user ID localpart for this device. This is preferable to storing the full user_id + -- as it is smaller, makes it clearer that we only manage devices for our own users, and may make + -- migration to different domain names easier. + localpart TEXT NOT NULL, + -- When this devices was first recognised on the network, as a unix timestamp (ms resolution). + created_ts BIGINT NOT NULL, + -- The display name, human friendlier than device_id and updatable + display_name TEXT + -- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app) +); + +-- Device IDs must be unique for a given user. +CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id); +` + +const insertDeviceSQL = "" + + "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" + + " RETURNING session_id" + +const selectDeviceByTokenSQL = "" + + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + +const selectDeviceByIDSQL = "" + + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" + +const selectDevicesByLocalpartSQL = "" + + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + +const updateDeviceNameSQL = "" + + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + +const deleteDeviceSQL = "" + + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + +const deleteDevicesByLocalpartSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1" + +const deleteDevicesSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" + +type devicesStatements struct { + insertDeviceStmt *sql.Stmt + selectDeviceByTokenStmt *sql.Stmt + selectDeviceByIDStmt *sql.Stmt + selectDevicesByLocalpartStmt *sql.Stmt + updateDeviceNameStmt *sql.Stmt + deleteDeviceStmt *sql.Stmt + deleteDevicesByLocalpartStmt *sql.Stmt + deleteDevicesStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + _, err = db.Exec(devicesSchema) + if err != nil { + return + } + if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { + return + } + if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { + return + } + if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { + return + } + if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { + return + } + if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { + return + } + if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { + return + } + if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { + return + } + if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertDevice creates a new device. Returns an error if any device with the same access token already exists. +// Returns an error if the user already has a device with the given device ID. +// Returns the device on success. +func (s *devicesStatements) insertDevice( + ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, + displayName *string, +) (*api.Device, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + var sessionID int64 + stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil { + return nil, err + } + return &api.Device{ + ID: id, + UserID: userutil.MakeUserID(localpart, s.serverName), + AccessToken: accessToken, + SessionID: sessionID, + }, nil +} + +// deleteDevice removes a single device by id and user localpart. +func (s *devicesStatements) deleteDevice( + ctx context.Context, txn *sql.Tx, id, localpart string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) + return err +} + +// deleteDevices removes a single or multiple devices by ids and user localpart. +// Returns an error if the execution failed. +func (s *devicesStatements) deleteDevices( + ctx context.Context, txn *sql.Tx, localpart string, devices []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) + _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) + return err +} + +// deleteDevicesByLocalpart removes all devices for the +// given user localpart. +func (s *devicesStatements) deleteDevicesByLocalpart( + ctx context.Context, txn *sql.Tx, localpart string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart) + return err +} + +func (s *devicesStatements) updateDeviceName( + ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, +) error { + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err +} + +func (s *devicesStatements) selectDeviceByToken( + ctx context.Context, accessToken string, +) (*api.Device, error) { + var dev api.Device + var localpart string + stmt := s.selectDeviceByTokenStmt + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + if err == nil { + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.AccessToken = accessToken + } + return &dev, err +} + +// selectDeviceByID retrieves a device from the database with the given user +// localpart and deviceID +func (s *devicesStatements) selectDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + var dev api.Device + stmt := s.selectDeviceByIDStmt + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + if err == nil { + dev.ID = deviceID + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + } + return &dev, err +} + +func (s *devicesStatements) selectDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + devices := []api.Device{} + + rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) + + if err != nil { + return devices, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed") + + for rows.Next() { + var dev api.Device + var id, displayname sql.NullString + err = rows.Scan(&id, &displayname) + if err != nil { + return devices, err + } + if id.Valid { + dev.ID = id.String + } + if displayname.Valid { + dev.DisplayName = displayname.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + + return devices, rows.Err() +} diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go new file mode 100644 index 00000000..801657bd --- /dev/null +++ b/userapi/storage/devices/postgres/storage.go @@ -0,0 +1,182 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +// The length of generated device IDs +var deviceIDByteLength = 6 + +// Database represents a device database. +type Database struct { + db *sql.DB + devices devicesStatements +} + +// NewDatabase creates a new device database +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (*Database, error) { + var db *sql.DB + var err error + if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil { + return nil, err + } + d := devicesStatements{} + if err = d.prepare(db, serverName); err != nil { + return nil, err + } + return &Database{db, d}, nil +} + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.devices.selectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.devices.selectDevicesByLocalpart(ctx, localpart) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go new file mode 100644 index 00000000..07ea5dca --- /dev/null +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -0,0 +1,249 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const devicesSchema = ` +-- This sequence is used for automatic allocation of session_id. +-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; + +-- Stores data about devices. +CREATE TABLE IF NOT EXISTS device_devices ( + access_token TEXT PRIMARY KEY, + session_id INTEGER, + device_id TEXT , + localpart TEXT , + created_ts BIGINT, + display_name TEXT, + + UNIQUE (localpart, device_id) +); +` + +const insertDeviceSQL = "" + + "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + +const selectDevicesCountSQL = "" + + "SELECT COUNT(access_token) FROM device_devices" + +const selectDeviceByTokenSQL = "" + + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" + +const selectDeviceByIDSQL = "" + + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" + +const selectDevicesByLocalpartSQL = "" + + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + +const updateDeviceNameSQL = "" + + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + +const deleteDeviceSQL = "" + + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" + +const deleteDevicesByLocalpartSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1" + +const deleteDevicesSQL = "" + + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" + +type devicesStatements struct { + db *sql.DB + insertDeviceStmt *sql.Stmt + selectDevicesCountStmt *sql.Stmt + selectDeviceByTokenStmt *sql.Stmt + selectDeviceByIDStmt *sql.Stmt + selectDevicesByLocalpartStmt *sql.Stmt + updateDeviceNameStmt *sql.Stmt + deleteDeviceStmt *sql.Stmt + deleteDevicesByLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + _, err = db.Exec(devicesSchema) + if err != nil { + return + } + if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { + return + } + if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { + return + } + if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { + return + } + if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { + return + } + if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil { + return + } + if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil { + return + } + if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil { + return + } + if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertDevice creates a new device. Returns an error if any device with the same access token already exists. +// Returns an error if the user already has a device with the given device ID. +// Returns the device on success. +func (s *devicesStatements) insertDevice( + ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, + displayName *string, +) (*api.Device, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + var sessionID int64 + countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + return nil, err + } + sessionID++ + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { + return nil, err + } + return &api.Device{ + ID: id, + UserID: userutil.MakeUserID(localpart, s.serverName), + AccessToken: accessToken, + SessionID: sessionID, + }, nil +} + +func (s *devicesStatements) deleteDevice( + ctx context.Context, txn *sql.Tx, id, localpart string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) + return err +} + +func (s *devicesStatements) deleteDevices( + ctx context.Context, txn *sql.Tx, localpart string, devices []string, +) error { + orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1) + prep, err := s.db.Prepare(orig) + if err != nil { + return err + } + stmt := sqlutil.TxStmt(txn, prep) + params := make([]interface{}, len(devices)+1) + params[0] = localpart + for i, v := range devices { + params[i+1] = v + } + params = append(params, params...) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *devicesStatements) deleteDevicesByLocalpart( + ctx context.Context, txn *sql.Tx, localpart string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart) + return err +} + +func (s *devicesStatements) updateDeviceName( + ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, +) error { + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err +} + +func (s *devicesStatements) selectDeviceByToken( + ctx context.Context, accessToken string, +) (*api.Device, error) { + var dev api.Device + var localpart string + stmt := s.selectDeviceByTokenStmt + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + if err == nil { + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.AccessToken = accessToken + } + return &dev, err +} + +// selectDeviceByID retrieves a device from the database with the given user +// localpart and deviceID +func (s *devicesStatements) selectDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + var dev api.Device + stmt := s.selectDeviceByIDStmt + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + if err == nil { + dev.ID = deviceID + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + } + return &dev, err +} + +func (s *devicesStatements) selectDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + devices := []api.Device{} + + rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) + + if err != nil { + return devices, err + } + + for rows.Next() { + var dev api.Device + var id, displayname sql.NullString + err = rows.Scan(&id, &displayname) + if err != nil { + return devices, err + } + if id.Valid { + dev.ID = id.String + } + if displayname.Valid { + dev.DisplayName = displayname.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + + return devices, nil +} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go new file mode 100644 index 00000000..f248abda --- /dev/null +++ b/userapi/storage/devices/sqlite3/storage.go @@ -0,0 +1,188 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + + _ "github.com/mattn/go-sqlite3" +) + +// The length of generated device IDs +var deviceIDByteLength = 6 + +// Database represents a device database. +type Database struct { + db *sql.DB + devices devicesStatements +} + +// NewDatabase creates a new device database +func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) { + var db *sql.DB + var err error + cs, err := sqlutil.ParseFileURI(dataSourceName) + if err != nil { + return nil, err + } + if db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil { + return nil, err + } + d := devicesStatements{} + if err = d.prepare(db, serverName); err != nil { + return nil, err + } + return &Database{db, d}, nil +} + +// GetDeviceByAccessToken returns the device matching the given access token. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByAccessToken( + ctx context.Context, token string, +) (*api.Device, error) { + return d.devices.selectDeviceByToken(ctx, token) +} + +// GetDeviceByID returns the device matching the given ID. +// Returns sql.ErrNoRows if no matching device was found. +func (d *Database) GetDeviceByID( + ctx context.Context, localpart, deviceID string, +) (*api.Device, error) { + return d.devices.selectDeviceByID(ctx, localpart, deviceID) +} + +// GetDevicesByLocalpart returns the devices matching the given localpart. +func (d *Database) GetDevicesByLocalpart( + ctx context.Context, localpart string, +) ([]api.Device, error) { + return d.devices.selectDevicesByLocalpart(ctx, localpart) +} + +// CreateDevice makes a new device associated with the given user ID localpart. +// If there is already a device with the same device ID for this user, that access token will be revoked +// and replaced with the given accessToken. If the given accessToken is already in use for another device, +// an error will be returned. +// If no device ID is given one is generated. +// Returns the device on success. +func (d *Database) CreateDevice( + ctx context.Context, localpart string, deviceID *string, accessToken string, + displayName *string, +) (dev *api.Device, returnErr error) { + if deviceID != nil { + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { + return err + } + + dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName) + return err + }) + } else { + // We generate device IDs in a loop in case its already taken. + // We cap this at going round 5 times to ensure we don't spin forever + var newDeviceID string + for i := 1; i <= 5; i++ { + newDeviceID, returnErr = generateDeviceID() + if returnErr != nil { + return + } + + returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + var err error + dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) + return err + }) + if returnErr == nil { + return + } + } + } + return +} + +// generateDeviceID creates a new device id. Returns an error if failed to generate +// random bytes. +func generateDeviceID() (string, error) { + b := make([]byte, deviceIDByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + // url-safe no padding + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// UpdateDevice updates the given device with the display name. +// Returns SQL error if there are problems and nil on success. +func (d *Database) UpdateDevice( + ctx context.Context, localpart, deviceID string, displayName *string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) + }) +} + +// RemoveDevice revokes a device by deleting the entry in the database +// matching with the given device ID and user ID localpart. +// If the device doesn't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevice( + ctx context.Context, deviceID, localpart string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveDevices revokes one or more devices by deleting the entry in the database +// matching with the given device IDs and user ID localpart. +// If the devices don't exist, it will not return an error +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveDevices( + ctx context.Context, localpart string, devices []string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + return err + } + return nil + }) +} + +// RemoveAllDevices revokes devices by deleting the entry in the +// database matching the given user ID localpart. +// If something went wrong during the deletion, it will return the SQL error. +func (d *Database) RemoveAllDevices( + ctx context.Context, localpart string, +) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + return err + } + return nil + }) +} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go new file mode 100644 index 00000000..e094d202 --- /dev/null +++ b/userapi/storage/devices/storage.go @@ -0,0 +1,43 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !wasm + +package devices + +import ( + "net/url" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" + "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" + "github.com/matrix-org/gomatrixserverlib" +) + +// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) +// and sets postgres connection parameters +func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (Database, error) { + uri, err := url.Parse(dataSourceName) + if err != nil { + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) + } + switch uri.Scheme { + case "postgres": + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) + case "file": + return sqlite3.NewDatabase(dataSourceName, serverName) + default: + return postgres.NewDatabase(dataSourceName, dbProperties, serverName) + } +} diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go new file mode 100644 index 00000000..a5a515ef --- /dev/null +++ b/userapi/storage/devices/storage_wasm.go @@ -0,0 +1,43 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package devices + +import ( + "fmt" + "net/url" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" + "github.com/matrix-org/gomatrixserverlib" +) + +func NewDatabase( + dataSourceName string, + dbProperties sqlutil.DbProperties, // nolint:unparam + serverName gomatrixserverlib.ServerName, +) (Database, error) { + uri, err := url.Parse(dataSourceName) + if err != nil { + return nil, fmt.Errorf("Cannot use postgres implementation") + } + switch uri.Scheme { + case "postgres": + return nil, fmt.Errorf("Cannot use postgres implementation") + case "file": + return sqlite3.NewDatabase(dataSourceName, serverName) + default: + return nil, fmt.Errorf("Cannot use postgres implementation") + } +} diff --git a/userapi/userapi.go b/userapi/userapi.go index 961d04fe..7aadec06 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -16,12 +16,12 @@ package userapi import ( "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 509bdd7e..163b10ec 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -8,13 +8,13 @@ import ( "testing" "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) |