aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-02-13 17:27:33 +0000
committerGitHub <noreply@github.com>2020-02-13 17:27:33 +0000
commitb6ea1bc67ab51667b9e139dd05e0778aca025501 (patch)
tree18569c317fd28544144c320ce844d93a8ff8ec5e
parent6942ee1de0250235164cf0ce45570b7fc919669d (diff)
Support sqlite in addition to postgres (#869)
* Move current work into single branch * Initial massaging of clientapi etc (not working yet) * Interfaces for accounts/devices databases * Duplicate postgres package for sqlite3 (no changes made to it yet) * Some keydb, accountdb, devicedb, common partition fixes, some more syncapi tweaking * Fix accounts DB, device DB * Update naffka dependency for SQLite * Naffka SQLite * Update naffka to latest master * SQLite support for federationsender * Mostly not-bad support for SQLite in syncapi (although there are problems where lots of events get classed incorrectly as backward extremities, probably because of IN/ANY clauses that are badly supported) * Update Dockerfile -> Go 1.13.7, add build-base (as gcc and friends are needed for SQLite) * Implement GET endpoints for account_data in clientapi * Nuke filtering for now... * Revert "Implement GET endpoints for account_data in clientapi" This reverts commit 4d80dff4583d278620d9b3ed437e9fcd8d4674ee. * Implement GET endpoints for account_data in clientapi (#861) * Implement GET endpoints for account_data in clientapi * Fix accountDB parameter * Remove fmt.Println * Fix insertAccountData SQLite query * Fix accountDB storage interfaces * Add empty push rules into account data on account creation (#862) * Put SaveAccountData into the right function this time * Not sure if roomserver is better or worse now * sqlite work * Allow empty last sent ID for the first event * sqlite: room creation works * Support sending messages * Nuke fmt.println * Move QueryVariadic etc into common, other device fixes * Fix some linter issues * Fix bugs * Fix some linting errors * Fix errcheck lint errors * Make naffka use postgres as fallback, fix couple of compile errors * What on earth happened to the /rooms/{roomID}/send/{eventType} routing Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
-rw-r--r--appservice/api/query.go2
-rw-r--r--appservice/appservice.go8
-rw-r--r--appservice/consumers/roomserver.go4
-rw-r--r--appservice/routing/routing.go2
-rw-r--r--clientapi/auth/storage/accounts/postgres/account_data_table.go (renamed from clientapi/auth/storage/accounts/account_data_table.go)2
-rw-r--r--clientapi/auth/storage/accounts/postgres/accounts_table.go (renamed from clientapi/auth/storage/accounts/accounts_table.go)2
-rw-r--r--clientapi/auth/storage/accounts/postgres/filter_table.go (renamed from clientapi/auth/storage/accounts/filter_table.go)2
-rw-r--r--clientapi/auth/storage/accounts/postgres/membership_table.go (renamed from clientapi/auth/storage/accounts/membership_table.go)2
-rw-r--r--clientapi/auth/storage/accounts/postgres/profile_table.go (renamed from clientapi/auth/storage/accounts/profile_table.go)2
-rw-r--r--clientapi/auth/storage/accounts/postgres/storage.go392
-rw-r--r--clientapi/auth/storage/accounts/postgres/threepid_table.go129
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/account_data_table.go141
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/accounts_table.go151
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/filter_table.go139
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/membership_table.go131
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/profile_table.go107
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/storage.go392
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/threepid_table.go (renamed from clientapi/auth/storage/accounts/threepid_table.go)2
-rw-r--r--clientapi/auth/storage/accounts/storage.go410
-rw-r--r--clientapi/auth/storage/devices/postgres/devices_table.go (renamed from clientapi/auth/storage/devices/devices_table.go)2
-rw-r--r--clientapi/auth/storage/devices/postgres/storage.go182
-rw-r--r--clientapi/auth/storage/devices/sqlite3/devices_table.go243
-rw-r--r--clientapi/auth/storage/devices/sqlite3/storage.go184
-rw-r--r--clientapi/auth/storage/devices/storage.go191
-rw-r--r--clientapi/clientapi.go4
-rw-r--r--clientapi/consumers/roomserver.go4
-rw-r--r--clientapi/routing/account_data.go4
-rw-r--r--clientapi/routing/createroom.go4
-rw-r--r--clientapi/routing/device.go10
-rw-r--r--clientapi/routing/filter.go4
-rw-r--r--clientapi/routing/joinroom.go2
-rw-r--r--clientapi/routing/login.go4
-rw-r--r--clientapi/routing/logout.go4
-rw-r--r--clientapi/routing/membership.go8
-rw-r--r--clientapi/routing/profile.go12
-rw-r--r--clientapi/routing/register.go30
-rw-r--r--clientapi/routing/room_tagging.go10
-rw-r--r--clientapi/routing/routing.go4
-rw-r--r--clientapi/routing/sendtyping.go2
-rw-r--r--clientapi/routing/threepid.go8
-rw-r--r--clientapi/threepid/invites.go6
-rw-r--r--cmd/kafka-producer/main.go2
-rw-r--r--common/basecomponent/base.go64
-rw-r--r--common/keydb/keydb.go3
-rw-r--r--common/keydb/sqlite3/keydb.go115
-rw-r--r--common/keydb/sqlite3/server_key_table.go142
-rw-r--r--common/partition_offset_table.go4
-rw-r--r--common/sql.go33
-rw-r--r--docker/Dockerfile6
-rw-r--r--docker/docker-compose.yml10
-rw-r--r--federationapi/federationapi.go4
-rw-r--r--federationapi/routing/devices.go2
-rw-r--r--federationapi/routing/profile.go2
-rw-r--r--federationapi/routing/routing.go4
-rw-r--r--federationapi/routing/threepid.go4
-rw-r--r--federationsender/storage/postgres/storage.go2
-rw-r--r--federationsender/storage/sqlite3/joined_hosts_table.go139
-rw-r--r--federationsender/storage/sqlite3/room_table.go101
-rw-r--r--federationsender/storage/sqlite3/storage.go124
-rw-r--r--federationsender/storage/storage.go3
-rw-r--r--go.mod20
-rw-r--r--go.sum42
-rw-r--r--mediaapi/mediaapi.go2
-rw-r--r--mediaapi/routing/routing.go2
-rw-r--r--publicroomsapi/publicroomsapi.go2
-rw-r--r--publicroomsapi/routing/routing.go2
-rw-r--r--publicroomsapi/storage/sqlite3/prepare.go36
-rw-r--r--publicroomsapi/storage/sqlite3/public_rooms_table.go277
-rw-r--r--publicroomsapi/storage/sqlite3/storage.go256
-rw-r--r--roomserver/input/events.go7
-rw-r--r--roomserver/input/latest_events.go7
-rw-r--r--roomserver/storage/sqlite3/event_json_table.go108
-rw-r--r--roomserver/storage/sqlite3/event_state_keys_table.go156
-rw-r--r--roomserver/storage/sqlite3/event_types_table.go153
-rw-r--r--roomserver/storage/sqlite3/events_table.go479
-rw-r--r--roomserver/storage/sqlite3/invite_table.go142
-rw-r--r--roomserver/storage/sqlite3/list.go18
-rw-r--r--roomserver/storage/sqlite3/membership_table.go180
-rw-r--r--roomserver/storage/sqlite3/prepare.go36
-rw-r--r--roomserver/storage/sqlite3/previous_events_table.go92
-rw-r--r--roomserver/storage/sqlite3/room_aliases_table.go135
-rw-r--r--roomserver/storage/sqlite3/rooms_table.go165
-rw-r--r--roomserver/storage/sqlite3/sql.go60
-rw-r--r--roomserver/storage/sqlite3/state_block_table.go292
-rw-r--r--roomserver/storage/sqlite3/state_block_table_test.go86
-rw-r--r--roomserver/storage/sqlite3/state_snapshot_table.go120
-rw-r--r--roomserver/storage/sqlite3/storage.go864
-rw-r--r--roomserver/storage/sqlite3/transactions_table.go86
-rw-r--r--roomserver/storage/storage.go14
-rw-r--r--syncapi/routing/routing.go2
-rw-r--r--syncapi/storage/postgres/syncserver.go17
-rw-r--r--syncapi/storage/sqlite3/account_data_table.go143
-rw-r--r--syncapi/storage/sqlite3/backward_extremities_table.go124
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go276
-rw-r--r--syncapi/storage/sqlite3/filtering.go36
-rw-r--r--syncapi/storage/sqlite3/invites_table.go157
-rw-r--r--syncapi/storage/sqlite3/output_room_events_table.go411
-rw-r--r--syncapi/storage/sqlite3/output_room_events_topology_table.go192
-rw-r--r--syncapi/storage/sqlite3/stream_id_table.go58
-rw-r--r--syncapi/storage/sqlite3/syncserver.go1197
-rw-r--r--syncapi/storage/storage.go3
-rw-r--r--syncapi/sync/requestpool.go4
-rw-r--r--syncapi/syncapi.go4
103 files changed, 9463 insertions, 706 deletions
diff --git a/appservice/api/query.go b/appservice/api/query.go
index 9542df56..7e61d623 100644
--- a/appservice/api/query.go
+++ b/appservice/api/query.go
@@ -140,7 +140,7 @@ func RetrieveUserProfile(
ctx context.Context,
userID string,
asAPI AppServiceQueryAPI,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
) (*authtypes.Profile, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
diff --git a/appservice/appservice.go b/appservice/appservice.go
index f2cbcce2..18179987 100644
--- a/appservice/appservice.go
+++ b/appservice/appservice.go
@@ -41,8 +41,8 @@ import (
// component.
func SetupAppServiceAPIComponent(
base *basecomponent.BaseDendrite,
- accountsDB *accounts.Database,
- deviceDB *devices.Database,
+ accountsDB accounts.Database,
+ deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient,
roomserverAliasAPI roomserverAPI.RoomserverAliasAPI,
roomserverQueryAPI roomserverAPI.RoomserverQueryAPI,
@@ -111,8 +111,8 @@ func SetupAppServiceAPIComponent(
// `sender_localpart` field of each application service if it doesn't
// exist already
func generateAppServiceAccount(
- accountsDB *accounts.Database,
- deviceDB *devices.Database,
+ accountsDB accounts.Database,
+ deviceDB devices.Database,
as config.ApplicationService,
) error {
ctx := context.Background()
diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go
index dbdae532..b9a56795 100644
--- a/appservice/consumers/roomserver.go
+++ b/appservice/consumers/roomserver.go
@@ -33,7 +33,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer
- db *accounts.Database
+ db accounts.Database
asDB *storage.Database
query api.RoomserverQueryAPI
alias api.RoomserverAliasAPI
@@ -46,7 +46,7 @@ type OutputRoomEventConsumer struct {
func NewOutputRoomEventConsumer(
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
- store *accounts.Database,
+ store accounts.Database,
appserviceDB *storage.Database,
queryAPI api.RoomserverQueryAPI,
aliasAPI api.RoomserverAliasAPI,
diff --git a/appservice/routing/routing.go b/appservice/routing/routing.go
index 8a24caad..42fa8052 100644
--- a/appservice/routing/routing.go
+++ b/appservice/routing/routing.go
@@ -38,7 +38,7 @@ const pathPrefixApp = "/_matrix/app/v1"
func Setup(
apiMux *mux.Router, cfg *config.Dendrite, // nolint: unparam
queryAPI api.RoomserverQueryAPI, aliasAPI api.RoomserverAliasAPI, // nolint: unparam
- accountDB *accounts.Database, // nolint: unparam
+ accountDB accounts.Database, // nolint: unparam
federation *gomatrixserverlib.FederationClient, // nolint: unparam
transactionsCache *transactions.Cache, // nolint: unparam
) {
diff --git a/clientapi/auth/storage/accounts/account_data_table.go b/clientapi/auth/storage/accounts/postgres/account_data_table.go
index 1b7484d8..d0cfcc0c 100644
--- a/clientapi/auth/storage/accounts/account_data_table.go
+++ b/clientapi/auth/storage/accounts/postgres/account_data_table.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package postgres
import (
"context"
diff --git a/clientapi/auth/storage/accounts/accounts_table.go b/clientapi/auth/storage/accounts/postgres/accounts_table.go
index e86654ec..6b8ed372 100644
--- a/clientapi/auth/storage/accounts/accounts_table.go
+++ b/clientapi/auth/storage/accounts/postgres/accounts_table.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package postgres
import (
"context"
diff --git a/clientapi/auth/storage/accounts/filter_table.go b/clientapi/auth/storage/accounts/postgres/filter_table.go
index 2b07ef17..c54e4bc4 100644
--- a/clientapi/auth/storage/accounts/filter_table.go
+++ b/clientapi/auth/storage/accounts/postgres/filter_table.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package postgres
import (
"context"
diff --git a/clientapi/auth/storage/accounts/membership_table.go b/clientapi/auth/storage/accounts/postgres/membership_table.go
index 7b7c50ac..426c2d6a 100644
--- a/clientapi/auth/storage/accounts/membership_table.go
+++ b/clientapi/auth/storage/accounts/postgres/membership_table.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package postgres
import (
"context"
diff --git a/clientapi/auth/storage/accounts/profile_table.go b/clientapi/auth/storage/accounts/postgres/profile_table.go
index 157bb99b..38c76c40 100644
--- a/clientapi/auth/storage/accounts/profile_table.go
+++ b/clientapi/auth/storage/accounts/postgres/profile_table.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package postgres
import (
"context"
diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go
new file mode 100644
index 00000000..cb74d131
--- /dev/null
+++ b/clientapi/auth/storage/accounts/postgres/storage.go
@@ -0,0 +1,392 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/gomatrixserverlib"
+ "golang.org/x/crypto/bcrypt"
+
+ // Import the postgres database driver.
+ _ "github.com/lib/pq"
+)
+
+// Database represents an account database
+type Database struct {
+ db *sql.DB
+ common.PartitionOffsetStatements
+ accounts accountsStatements
+ profiles profilesStatements
+ memberships membershipStatements
+ accountDatas accountDataStatements
+ threepids threepidStatements
+ filter filterStatements
+ serverName gomatrixserverlib.ServerName
+}
+
+// NewDatabase creates a new accounts and profiles database
+func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
+ var db *sql.DB
+ var err error
+ if db, err = sql.Open("postgres", dataSourceName); err != nil {
+ return nil, err
+ }
+ partitions := common.PartitionOffsetStatements{}
+ if err = partitions.Prepare(db, "account"); err != nil {
+ return nil, err
+ }
+ a := accountsStatements{}
+ if err = a.prepare(db, serverName); err != nil {
+ return nil, err
+ }
+ p := profilesStatements{}
+ if err = p.prepare(db); err != nil {
+ return nil, err
+ }
+ m := membershipStatements{}
+ if err = m.prepare(db); err != nil {
+ return nil, err
+ }
+ ac := accountDataStatements{}
+ if err = ac.prepare(db); err != nil {
+ return nil, err
+ }
+ t := threepidStatements{}
+ if err = t.prepare(db); err != nil {
+ return nil, err
+ }
+ f := filterStatements{}
+ if err = f.prepare(db); err != nil {
+ return nil, err
+ }
+ return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
+}
+
+// GetAccountByPassword returns the account associated with the given localpart and password.
+// Returns sql.ErrNoRows if no account exists which matches the given localpart.
+func (d *Database) GetAccountByPassword(
+ ctx context.Context, localpart, plaintextPassword string,
+) (*authtypes.Account, error) {
+ hash, err := d.accounts.selectPasswordHash(ctx, localpart)
+ if err != nil {
+ return nil, err
+ }
+ if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
+ return nil, err
+ }
+ return d.accounts.selectAccountByLocalpart(ctx, localpart)
+}
+
+// GetProfileByLocalpart returns the profile associated with the given localpart.
+// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
+func (d *Database) GetProfileByLocalpart(
+ ctx context.Context, localpart string,
+) (*authtypes.Profile, error) {
+ return d.profiles.selectProfileByLocalpart(ctx, localpart)
+}
+
+// SetAvatarURL updates the avatar URL of the profile associated with the given
+// localpart. Returns an error if something went wrong with the SQL query
+func (d *Database) SetAvatarURL(
+ ctx context.Context, localpart string, avatarURL string,
+) error {
+ return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
+}
+
+// SetDisplayName updates the display name of the profile associated with the given
+// localpart. Returns an error if something went wrong with the SQL query
+func (d *Database) SetDisplayName(
+ ctx context.Context, localpart string, displayName string,
+) error {
+ return d.profiles.setDisplayName(ctx, localpart, displayName)
+}
+
+// CreateAccount makes a new account with the given login name and password, and creates an empty profile
+// for this account. If no password is supplied, the account will be a passwordless account. If the
+// account already exists, it will return nil, nil.
+func (d *Database) CreateAccount(
+ ctx context.Context, localpart, plaintextPassword, appserviceID string,
+) (*authtypes.Account, error) {
+ var err error
+
+ // Generate a password hash if this is not a password-less user
+ hash := ""
+ if plaintextPassword != "" {
+ hash, err = hashPassword(plaintextPassword)
+ if err != nil {
+ return nil, err
+ }
+ }
+ if err := d.profiles.insertProfile(ctx, localpart); err != nil {
+ if common.IsUniqueConstraintViolationErr(err) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
+ "global": {
+ "content": [],
+ "override": [],
+ "room": [],
+ "sender": [],
+ "underride": []
+ }
+ }`); err != nil {
+ return nil, err
+ }
+ return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
+}
+
+// SaveMembership saves the user matching a given localpart as a member of a given
+// room. It also stores the ID of the membership event.
+// If a membership already exists between the user and the room, or if the
+// insert fails, returns the SQL error
+func (d *Database) saveMembership(
+ ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
+) error {
+ return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
+}
+
+// removeMembershipsByEventIDs removes the memberships corresponding to the
+// `join` membership events IDs in the eventIDs slice.
+// If the removal fails, or if there is no membership to remove, returns an error
+func (d *Database) removeMembershipsByEventIDs(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) error {
+ return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
+}
+
+// UpdateMemberships adds the "join" membership events included in a given state
+// events array, and removes those which ID is included in a given array of events
+// IDs. All of the process is run in a transaction, which commits only once/if every
+// insertion and deletion has been successfully processed.
+// Returns a SQL error if there was an issue with any part of the process
+func (d *Database) UpdateMemberships(
+ ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
+ return err
+ }
+
+ for _, event := range eventsToAdd {
+ if err := d.newMembership(ctx, txn, event); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+}
+
+// GetMembershipInRoomByLocalpart returns the membership for an user
+// matching the given localpart if he is a member of the room matching roomID,
+// if not sql.ErrNoRows is returned.
+// If there was an issue during the retrieval, returns the SQL error
+func (d *Database) GetMembershipInRoomByLocalpart(
+ ctx context.Context, localpart, roomID string,
+) (authtypes.Membership, error) {
+ return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
+}
+
+// GetMembershipsByLocalpart returns an array containing the memberships for all
+// the rooms a user matching a given localpart is a member of
+// If no membership match the given localpart, returns an empty array
+// If there was an issue during the retrieval, returns the SQL error
+func (d *Database) GetMembershipsByLocalpart(
+ ctx context.Context, localpart string,
+) (memberships []authtypes.Membership, err error) {
+ return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
+}
+
+// newMembership saves a new membership in the database.
+// If the event isn't a valid m.room.member event with type `join`, does nothing.
+// If an error occurred, returns the SQL error
+func (d *Database) newMembership(
+ ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
+) error {
+ if ev.Type() == "m.room.member" && ev.StateKey() != nil {
+ localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
+ if err != nil {
+ return err
+ }
+
+ // We only want state events from local users
+ if string(serverName) != string(d.serverName) {
+ return nil
+ }
+
+ eventID := ev.EventID()
+ roomID := ev.RoomID()
+ membership, err := ev.Membership()
+ if err != nil {
+ return err
+ }
+
+ // Only "join" membership events can be considered as new memberships
+ if membership == gomatrixserverlib.Join {
+ if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// SaveAccountData saves new account data for a given user and a given room.
+// If the account data is not specific to a room, the room ID should be an empty string
+// If an account data already exists for a given set (user, room, data type), it will
+// update the corresponding row with the new content
+// Returns a SQL error if there was an issue with the insertion/update
+func (d *Database) SaveAccountData(
+ ctx context.Context, localpart, roomID, dataType, content string,
+) error {
+ return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
+}
+
+// GetAccountData returns account data related to a given localpart
+// If no account data could be found, returns an empty arrays
+// Returns an error if there was an issue with the retrieval
+func (d *Database) GetAccountData(ctx context.Context, localpart string) (
+ global []gomatrixserverlib.ClientEvent,
+ rooms map[string][]gomatrixserverlib.ClientEvent,
+ err error,
+) {
+ return d.accountDatas.selectAccountData(ctx, localpart)
+}
+
+// GetAccountDataByType returns account data matching a given
+// localpart, room ID and type.
+// If no account data could be found, returns nil
+// Returns an error if there was an issue with the retrieval
+func (d *Database) GetAccountDataByType(
+ ctx context.Context, localpart, roomID, dataType string,
+) (data *gomatrixserverlib.ClientEvent, err error) {
+ return d.accountDatas.selectAccountDataByType(
+ ctx, localpart, roomID, dataType,
+ )
+}
+
+// GetNewNumericLocalpart generates and returns a new unused numeric localpart
+func (d *Database) GetNewNumericLocalpart(
+ ctx context.Context,
+) (int64, error) {
+ return d.accounts.selectNewNumericLocalpart(ctx)
+}
+
+func hashPassword(plaintext string) (hash string, err error) {
+ hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
+ return string(hashBytes), err
+}
+
+// Err3PIDInUse is the error returned when trying to save an association involving
+// a third-party identifier which is already associated to a local user.
+var Err3PIDInUse = errors.New("This third-party identifier is already in use")
+
+// SaveThreePIDAssociation saves the association between a third party identifier
+// and a local Matrix user (identified by the user's ID's local part).
+// If the third-party identifier is already part of an association, returns Err3PIDInUse.
+// Returns an error if there was a problem talking to the database.
+func (d *Database) SaveThreePIDAssociation(
+ ctx context.Context, threepid, localpart, medium string,
+) (err error) {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ user, err := d.threepids.selectLocalpartForThreePID(
+ ctx, txn, threepid, medium,
+ )
+ if err != nil {
+ return err
+ }
+
+ if len(user) > 0 {
+ return Err3PIDInUse
+ }
+
+ return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
+ })
+}
+
+// RemoveThreePIDAssociation removes the association involving a given third-party
+// identifier.
+// If no association exists involving this third-party identifier, returns nothing.
+// If there was a problem talking to the database, returns an error.
+func (d *Database) RemoveThreePIDAssociation(
+ ctx context.Context, threepid string, medium string,
+) (err error) {
+ return d.threepids.deleteThreePID(ctx, threepid, medium)
+}
+
+// GetLocalpartForThreePID looks up the localpart associated with a given third-party
+// identifier.
+// If no association involves the given third-party idenfitier, returns an empty
+// string.
+// Returns an error if there was a problem talking to the database.
+func (d *Database) GetLocalpartForThreePID(
+ ctx context.Context, threepid string, medium string,
+) (localpart string, err error) {
+ return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
+}
+
+// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
+// a given local user.
+// If no association is known for this user, returns an empty slice.
+// Returns an error if there was an issue talking to the database.
+func (d *Database) GetThreePIDsForLocalpart(
+ ctx context.Context, localpart string,
+) (threepids []authtypes.ThreePID, err error) {
+ return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
+}
+
+// GetFilter looks up the filter associated with a given local user and filter ID.
+// Returns a filter structure. Otherwise returns an error if no such filter exists
+// or if there was an error talking to the database.
+func (d *Database) GetFilter(
+ ctx context.Context, localpart string, filterID string,
+) (*gomatrixserverlib.Filter, error) {
+ return d.filter.selectFilter(ctx, localpart, filterID)
+}
+
+// PutFilter puts the passed filter into the database.
+// Returns the filterID as a string. Otherwise returns an error if something
+// goes wrong.
+func (d *Database) PutFilter(
+ ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
+) (string, error) {
+ return d.filter.insertFilter(ctx, filter, localpart)
+}
+
+// CheckAccountAvailability checks if the username/localpart is already present
+// in the database.
+// If the DB returns sql.ErrNoRows the Localpart isn't taken.
+func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
+ _, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
+ if err == sql.ErrNoRows {
+ return true, nil
+ }
+ return false, err
+}
+
+// GetAccountByLocalpart returns the account associated with the given localpart.
+// This function assumes the request is authenticated or the account data is used only internally.
+// Returns sql.ErrNoRows if no account exists which matches the given localpart.
+func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
+) (*authtypes.Account, error) {
+ return d.accounts.selectAccountByLocalpart(ctx, localpart)
+}
diff --git a/clientapi/auth/storage/accounts/postgres/threepid_table.go b/clientapi/auth/storage/accounts/postgres/threepid_table.go
new file mode 100644
index 00000000..851b4a90
--- /dev/null
+++ b/clientapi/auth/storage/accounts/postgres/threepid_table.go
@@ -0,0 +1,129 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+)
+
+const threepidSchema = `
+-- Stores data about third party identifiers
+CREATE TABLE IF NOT EXISTS account_threepid (
+ -- The third party identifier
+ threepid TEXT NOT NULL,
+ -- The 3PID medium
+ medium TEXT NOT NULL DEFAULT 'email',
+ -- The localpart of the Matrix user ID associated to this 3PID
+ localpart TEXT NOT NULL,
+
+ PRIMARY KEY(threepid, medium)
+);
+
+CREATE INDEX IF NOT EXISTS account_threepid_localpart ON account_threepid(localpart);
+`
+
+const selectLocalpartForThreePIDSQL = "" +
+ "SELECT localpart FROM account_threepid WHERE threepid = $1 AND medium = $2"
+
+const selectThreePIDsForLocalpartSQL = "" +
+ "SELECT threepid, medium FROM account_threepid WHERE localpart = $1"
+
+const insertThreePIDSQL = "" +
+ "INSERT INTO account_threepid (threepid, medium, localpart) VALUES ($1, $2, $3)"
+
+const deleteThreePIDSQL = "" +
+ "DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
+
+type threepidStatements struct {
+ selectLocalpartForThreePIDStmt *sql.Stmt
+ selectThreePIDsForLocalpartStmt *sql.Stmt
+ insertThreePIDStmt *sql.Stmt
+ deleteThreePIDStmt *sql.Stmt
+}
+
+func (s *threepidStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(threepidSchema)
+ if err != nil {
+ return
+ }
+ if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil {
+ return
+ }
+ if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil {
+ return
+ }
+ if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil {
+ return
+ }
+ if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil {
+ return
+ }
+
+ return
+}
+
+func (s *threepidStatements) selectLocalpartForThreePID(
+ ctx context.Context, txn *sql.Tx, threepid string, medium string,
+) (localpart string, err error) {
+ stmt := common.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
+ err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
+ if err == sql.ErrNoRows {
+ return "", nil
+ }
+ return
+}
+
+func (s *threepidStatements) selectThreePIDsForLocalpart(
+ ctx context.Context, localpart string,
+) (threepids []authtypes.ThreePID, err error) {
+ rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
+ if err != nil {
+ return
+ }
+
+ threepids = []authtypes.ThreePID{}
+ for rows.Next() {
+ var threepid string
+ var medium string
+ if err = rows.Scan(&threepid, &medium); err != nil {
+ return
+ }
+ threepids = append(threepids, authtypes.ThreePID{
+ Address: threepid,
+ Medium: medium,
+ })
+ }
+
+ return
+}
+
+func (s *threepidStatements) insertThreePID(
+ ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
+) (err error) {
+ stmt := common.TxStmt(txn, s.insertThreePIDStmt)
+ _, err = stmt.ExecContext(ctx, threepid, medium, localpart)
+ return
+}
+
+func (s *threepidStatements) deleteThreePID(
+ ctx context.Context, threepid string, medium string) (err error) {
+ _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
+ return
+}
diff --git a/clientapi/auth/storage/accounts/sqlite3/account_data_table.go b/clientapi/auth/storage/accounts/sqlite3/account_data_table.go
new file mode 100644
index 00000000..c2143881
--- /dev/null
+++ b/clientapi/auth/storage/accounts/sqlite3/account_data_table.go
@@ -0,0 +1,141 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const accountDataSchema = `
+-- Stores data about accounts data.
+CREATE TABLE IF NOT EXISTS account_data (
+ -- The Matrix user ID localpart for this account
+ localpart TEXT NOT NULL,
+ -- The room ID for this data (empty string if not specific to a room)
+ room_id TEXT,
+ -- The account data type
+ type TEXT NOT NULL,
+ -- The account data content
+ content TEXT NOT NULL,
+
+ PRIMARY KEY(localpart, room_id, type)
+);
+`
+
+const insertAccountDataSQL = `
+ INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
+ ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
+`
+
+const selectAccountDataSQL = "" +
+ "SELECT room_id, type, content FROM account_data WHERE localpart = $1"
+
+const selectAccountDataByTypeSQL = "" +
+ "SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
+
+type accountDataStatements struct {
+ insertAccountDataStmt *sql.Stmt
+ selectAccountDataStmt *sql.Stmt
+ selectAccountDataByTypeStmt *sql.Stmt
+}
+
+func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(accountDataSchema)
+ if err != nil {
+ return
+ }
+ if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
+ return
+ }
+ if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil {
+ return
+ }
+ if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *accountDataStatements) insertAccountData(
+ ctx context.Context, localpart, roomID, dataType, content string,
+) (err error) {
+ stmt := s.insertAccountDataStmt
+ _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
+ return
+}
+
+func (s *accountDataStatements) selectAccountData(
+ ctx context.Context, localpart string,
+) (
+ global []gomatrixserverlib.ClientEvent,
+ rooms map[string][]gomatrixserverlib.ClientEvent,
+ err error,
+) {
+ rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
+ if err != nil {
+ return
+ }
+
+ global = []gomatrixserverlib.ClientEvent{}
+ rooms = make(map[string][]gomatrixserverlib.ClientEvent)
+
+ for rows.Next() {
+ var roomID string
+ var dataType string
+ var content []byte
+
+ if err = rows.Scan(&roomID, &dataType, &content); err != nil {
+ return
+ }
+
+ ac := gomatrixserverlib.ClientEvent{
+ Type: dataType,
+ Content: content,
+ }
+
+ if len(roomID) > 0 {
+ rooms[roomID] = append(rooms[roomID], ac)
+ } else {
+ global = append(global, ac)
+ }
+ }
+
+ return
+}
+
+func (s *accountDataStatements) selectAccountDataByType(
+ ctx context.Context, localpart, roomID, dataType string,
+) (data *gomatrixserverlib.ClientEvent, err error) {
+ stmt := s.selectAccountDataByTypeStmt
+ var content []byte
+
+ if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+
+ return
+ }
+
+ data = &gomatrixserverlib.ClientEvent{
+ Type: dataType,
+ Content: content,
+ }
+
+ return
+}
diff --git a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go
new file mode 100644
index 00000000..b029951f
--- /dev/null
+++ b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go
@@ -0,0 +1,151 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/clientapi/userutil"
+ "github.com/matrix-org/gomatrixserverlib"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const accountsSchema = `
+-- Stores data about accounts.
+CREATE TABLE IF NOT EXISTS account_accounts (
+ -- The Matrix user ID localpart for this account
+ localpart TEXT NOT NULL PRIMARY KEY,
+ -- When this account was first created, as a unix timestamp (ms resolution).
+ created_ts BIGINT NOT NULL,
+ -- The password hash for this account. Can be NULL if this is a passwordless account.
+ password_hash TEXT,
+ -- Identifies which application service this account belongs to, if any.
+ appservice_id TEXT
+ -- TODO:
+ -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
+);
+`
+
+const insertAccountSQL = "" +
+ "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
+
+const selectAccountByLocalpartSQL = "" +
+ "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
+
+const selectPasswordHashSQL = "" +
+ "SELECT password_hash FROM account_accounts WHERE localpart = $1"
+
+const selectNewNumericLocalpartSQL = "" +
+ "SELECT COUNT(localpart) FROM account_accounts"
+
+// TODO: Update password
+
+type accountsStatements struct {
+ insertAccountStmt *sql.Stmt
+ selectAccountByLocalpartStmt *sql.Stmt
+ selectPasswordHashStmt *sql.Stmt
+ selectNewNumericLocalpartStmt *sql.Stmt
+ serverName gomatrixserverlib.ServerName
+}
+
+func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+ _, err = db.Exec(accountsSchema)
+ if err != nil {
+ return
+ }
+ if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
+ return
+ }
+ if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
+ return
+ }
+ if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
+ return
+ }
+ if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
+ return
+ }
+ s.serverName = server
+ return
+}
+
+// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
+// this account will be passwordless. Returns an error if this account already exists. Returns the account
+// on success.
+func (s *accountsStatements) insertAccount(
+ ctx context.Context, localpart, hash, appserviceID string,
+) (*authtypes.Account, error) {
+ createdTimeMS := time.Now().UnixNano() / 1000000
+ stmt := s.insertAccountStmt
+
+ var err error
+ if appserviceID == "" {
+ _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
+ } else {
+ _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ return &authtypes.Account{
+ Localpart: localpart,
+ UserID: userutil.MakeUserID(localpart, s.serverName),
+ ServerName: s.serverName,
+ AppServiceID: appserviceID,
+ }, nil
+}
+
+func (s *accountsStatements) selectPasswordHash(
+ ctx context.Context, localpart string,
+) (hash string, err error) {
+ err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
+ return
+}
+
+func (s *accountsStatements) selectAccountByLocalpart(
+ ctx context.Context, localpart string,
+) (*authtypes.Account, error) {
+ var appserviceIDPtr sql.NullString
+ var acc authtypes.Account
+
+ stmt := s.selectAccountByLocalpartStmt
+ err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
+ if err != nil {
+ if err != sql.ErrNoRows {
+ log.WithError(err).Error("Unable to retrieve user from the db")
+ }
+ return nil, err
+ }
+ if appserviceIDPtr.Valid {
+ acc.AppServiceID = appserviceIDPtr.String
+ }
+
+ acc.UserID = userutil.MakeUserID(localpart, s.serverName)
+ acc.ServerName = s.serverName
+
+ return &acc, nil
+}
+
+func (s *accountsStatements) selectNewNumericLocalpart(
+ ctx context.Context,
+) (id int64, err error) {
+ err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
+ return
+}
diff --git a/clientapi/auth/storage/accounts/sqlite3/filter_table.go b/clientapi/auth/storage/accounts/sqlite3/filter_table.go
new file mode 100644
index 00000000..691ead77
--- /dev/null
+++ b/clientapi/auth/storage/accounts/sqlite3/filter_table.go
@@ -0,0 +1,139 @@
+// Copyright 2017 Jan Christian Grünhage
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const filterSchema = `
+-- Stores data about filters
+CREATE TABLE IF NOT EXISTS account_filter (
+ -- The filter
+ filter TEXT NOT NULL,
+ -- The ID
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The localpart of the Matrix user ID associated to this filter
+ localpart TEXT NOT NULL,
+
+ UNIQUE (id, localpart)
+);
+
+CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
+`
+
+const selectFilterSQL = "" +
+ "SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2"
+
+const selectFilterIDByContentSQL = "" +
+ "SELECT id FROM account_filter WHERE localpart = $1 AND filter = $2"
+
+const insertFilterSQL = "" +
+ "INSERT INTO account_filter (filter, localpart) VALUES ($1, $2)"
+
+const selectLastInsertedFilterIDSQL = "" +
+ "SELECT id FROM account_filter WHERE rowid = last_insert_rowid()"
+
+type filterStatements struct {
+ selectFilterStmt *sql.Stmt
+ selectLastInsertedFilterIDStmt *sql.Stmt
+ selectFilterIDByContentStmt *sql.Stmt
+ insertFilterStmt *sql.Stmt
+}
+
+func (s *filterStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(filterSchema)
+ if err != nil {
+ return
+ }
+ if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
+ return
+ }
+ if s.selectLastInsertedFilterIDStmt, err = db.Prepare(selectLastInsertedFilterIDSQL); err != nil {
+ return
+ }
+ if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
+ return
+ }
+ if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *filterStatements) selectFilter(
+ ctx context.Context, localpart string, filterID string,
+) (*gomatrixserverlib.Filter, error) {
+ // Retrieve filter from database (stored as canonical JSON)
+ var filterData []byte
+ err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
+ if err != nil {
+ return nil, err
+ }
+
+ // Unmarshal JSON into Filter struct
+ var filter gomatrixserverlib.Filter
+ if err = json.Unmarshal(filterData, &filter); err != nil {
+ return nil, err
+ }
+ return &filter, nil
+}
+
+func (s *filterStatements) insertFilter(
+ ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
+) (filterID string, err error) {
+ var existingFilterID string
+
+ // Serialise json
+ filterJSON, err := json.Marshal(filter)
+ if err != nil {
+ return "", err
+ }
+ // Remove whitespaces and sort JSON data
+ // needed to prevent from inserting the same filter multiple times
+ filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
+ if err != nil {
+ return "", err
+ }
+
+ // Check if filter already exists in the database using its localpart and content
+ //
+ // This can result in a race condition when two clients try to insert the
+ // same filter and localpart at the same time, however this is not a
+ // problem as both calls will result in the same filterID
+ err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
+ localpart, filterJSON).Scan(&existingFilterID)
+ if err != nil && err != sql.ErrNoRows {
+ return "", err
+ }
+ // If it does, return the existing ID
+ if existingFilterID != "" {
+ return existingFilterID, err
+ }
+
+ // Otherwise insert the filter and return the new ID
+ if _, err = s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart); err != nil {
+ return "", err
+ }
+ row := s.selectLastInsertedFilterIDStmt.QueryRowContext(ctx)
+ if err := row.Scan(&filterID); err != nil {
+ return "", err
+ }
+ return
+}
diff --git a/clientapi/auth/storage/accounts/sqlite3/membership_table.go b/clientapi/auth/storage/accounts/sqlite3/membership_table.go
new file mode 100644
index 00000000..8e5e69ba
--- /dev/null
+++ b/clientapi/auth/storage/accounts/sqlite3/membership_table.go
@@ -0,0 +1,131 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+)
+
+const membershipSchema = `
+-- Stores data about users memberships to rooms.
+CREATE TABLE IF NOT EXISTS account_memberships (
+ -- The Matrix user ID localpart for the member
+ localpart TEXT NOT NULL,
+ -- The room this user is a member of
+ room_id TEXT NOT NULL,
+ -- The ID of the join membership event
+ event_id TEXT NOT NULL,
+
+ -- A user can only be member of a room once
+ PRIMARY KEY (localpart, room_id),
+
+ UNIQUE (event_id)
+);
+`
+
+const insertMembershipSQL = `
+ INSERT INTO account_memberships(localpart, room_id, event_id) VALUES ($1, $2, $3)
+ ON CONFLICT (localpart, room_id) DO UPDATE SET event_id = EXCLUDED.event_id
+`
+
+const selectMembershipsByLocalpartSQL = "" +
+ "SELECT room_id, event_id FROM account_memberships WHERE localpart = $1"
+
+const selectMembershipInRoomByLocalpartSQL = "" +
+ "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2"
+
+const deleteMembershipsByEventIDsSQL = "" +
+ "DELETE FROM account_memberships WHERE event_id IN ($1)"
+
+type membershipStatements struct {
+ deleteMembershipsByEventIDsStmt *sql.Stmt
+ insertMembershipStmt *sql.Stmt
+ selectMembershipInRoomByLocalpartStmt *sql.Stmt
+ selectMembershipsByLocalpartStmt *sql.Stmt
+}
+
+func (s *membershipStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(membershipSchema)
+ if err != nil {
+ return
+ }
+ if s.deleteMembershipsByEventIDsStmt, err = db.Prepare(deleteMembershipsByEventIDsSQL); err != nil {
+ return
+ }
+ if s.insertMembershipStmt, err = db.Prepare(insertMembershipSQL); err != nil {
+ return
+ }
+ if s.selectMembershipInRoomByLocalpartStmt, err = db.Prepare(selectMembershipInRoomByLocalpartSQL); err != nil {
+ return
+ }
+ if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *membershipStatements) insertMembership(
+ ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
+) (err error) {
+ stmt := txn.Stmt(s.insertMembershipStmt)
+ _, err = stmt.ExecContext(ctx, localpart, roomID, eventID)
+ return
+}
+
+func (s *membershipStatements) deleteMembershipsByEventIDs(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) (err error) {
+ stmt := txn.Stmt(s.deleteMembershipsByEventIDsStmt)
+ _, err = stmt.ExecContext(ctx, pq.StringArray(eventIDs))
+ return
+}
+
+func (s *membershipStatements) selectMembershipInRoomByLocalpart(
+ ctx context.Context, localpart, roomID string,
+) (authtypes.Membership, error) {
+ membership := authtypes.Membership{Localpart: localpart, RoomID: roomID}
+ stmt := s.selectMembershipInRoomByLocalpartStmt
+ err := stmt.QueryRowContext(ctx, localpart, roomID).Scan(&membership.EventID)
+
+ return membership, err
+}
+
+func (s *membershipStatements) selectMembershipsByLocalpart(
+ ctx context.Context, localpart string,
+) (memberships []authtypes.Membership, err error) {
+ stmt := s.selectMembershipsByLocalpartStmt
+ rows, err := stmt.QueryContext(ctx, localpart)
+ if err != nil {
+ return
+ }
+
+ memberships = []authtypes.Membership{}
+
+ defer rows.Close() // nolint: errcheck
+ for rows.Next() {
+ var m authtypes.Membership
+ m.Localpart = localpart
+ if err := rows.Scan(&m.RoomID, &m.EventID); err != nil {
+ return nil, err
+ }
+ memberships = append(memberships, m)
+ }
+
+ return
+}
diff --git a/clientapi/auth/storage/accounts/sqlite3/profile_table.go b/clientapi/auth/storage/accounts/sqlite3/profile_table.go
new file mode 100644
index 00000000..7af8307e
--- /dev/null
+++ b/clientapi/auth/storage/accounts/sqlite3/profile_table.go
@@ -0,0 +1,107 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+)
+
+const profilesSchema = `
+-- Stores data about accounts profiles.
+CREATE TABLE IF NOT EXISTS account_profiles (
+ -- The Matrix user ID localpart for this account
+ localpart TEXT NOT NULL PRIMARY KEY,
+ -- The display name for this account
+ display_name TEXT,
+ -- The URL of the avatar for this account
+ avatar_url TEXT
+);
+`
+
+const insertProfileSQL = "" +
+ "INSERT INTO account_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
+
+const selectProfileByLocalpartSQL = "" +
+ "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart = $1"
+
+const setAvatarURLSQL = "" +
+ "UPDATE account_profiles SET avatar_url = $1 WHERE localpart = $2"
+
+const setDisplayNameSQL = "" +
+ "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
+
+type profilesStatements struct {
+ insertProfileStmt *sql.Stmt
+ selectProfileByLocalpartStmt *sql.Stmt
+ setAvatarURLStmt *sql.Stmt
+ setDisplayNameStmt *sql.Stmt
+}
+
+func (s *profilesStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(profilesSchema)
+ if err != nil {
+ return
+ }
+ if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil {
+ return
+ }
+ if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil {
+ return
+ }
+ if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil {
+ return
+ }
+ if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *profilesStatements) insertProfile(
+ ctx context.Context, localpart string,
+) (err error) {
+ _, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "")
+ return
+}
+
+func (s *profilesStatements) selectProfileByLocalpart(
+ ctx context.Context, localpart string,
+) (*authtypes.Profile, error) {
+ var profile authtypes.Profile
+ err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
+ &profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
+ )
+ if err != nil {
+ return nil, err
+ }
+ return &profile, nil
+}
+
+func (s *profilesStatements) setAvatarURL(
+ ctx context.Context, localpart string, avatarURL string,
+) (err error) {
+ _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
+ return
+}
+
+func (s *profilesStatements) setDisplayName(
+ ctx context.Context, localpart string, displayName string,
+) (err error) {
+ _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
+ return
+}
diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go
new file mode 100644
index 00000000..199c4606
--- /dev/null
+++ b/clientapi/auth/storage/accounts/sqlite3/storage.go
@@ -0,0 +1,392 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/gomatrixserverlib"
+ "golang.org/x/crypto/bcrypt"
+
+ // Import the postgres database driver.
+ _ "github.com/mattn/go-sqlite3"
+)
+
+// Database represents an account database
+type Database struct {
+ db *sql.DB
+ common.PartitionOffsetStatements
+ accounts accountsStatements
+ profiles profilesStatements
+ memberships membershipStatements
+ accountDatas accountDataStatements
+ threepids threepidStatements
+ filter filterStatements
+ serverName gomatrixserverlib.ServerName
+}
+
+// NewDatabase creates a new accounts and profiles database
+func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
+ var db *sql.DB
+ var err error
+ if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
+ return nil, err
+ }
+ partitions := common.PartitionOffsetStatements{}
+ if err = partitions.Prepare(db, "account"); err != nil {
+ return nil, err
+ }
+ a := accountsStatements{}
+ if err = a.prepare(db, serverName); err != nil {
+ return nil, err
+ }
+ p := profilesStatements{}
+ if err = p.prepare(db); err != nil {
+ return nil, err
+ }
+ m := membershipStatements{}
+ if err = m.prepare(db); err != nil {
+ return nil, err
+ }
+ ac := accountDataStatements{}
+ if err = ac.prepare(db); err != nil {
+ return nil, err
+ }
+ t := threepidStatements{}
+ if err = t.prepare(db); err != nil {
+ return nil, err
+ }
+ f := filterStatements{}
+ if err = f.prepare(db); err != nil {
+ return nil, err
+ }
+ return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
+}
+
+// GetAccountByPassword returns the account associated with the given localpart and password.
+// Returns sql.ErrNoRows if no account exists which matches the given localpart.
+func (d *Database) GetAccountByPassword(
+ ctx context.Context, localpart, plaintextPassword string,
+) (*authtypes.Account, error) {
+ hash, err := d.accounts.selectPasswordHash(ctx, localpart)
+ if err != nil {
+ return nil, err
+ }
+ if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
+ return nil, err
+ }
+ return d.accounts.selectAccountByLocalpart(ctx, localpart)
+}
+
+// GetProfileByLocalpart returns the profile associated with the given localpart.
+// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
+func (d *Database) GetProfileByLocalpart(
+ ctx context.Context, localpart string,
+) (*authtypes.Profile, error) {
+ return d.profiles.selectProfileByLocalpart(ctx, localpart)
+}
+
+// SetAvatarURL updates the avatar URL of the profile associated with the given
+// localpart. Returns an error if something went wrong with the SQL query
+func (d *Database) SetAvatarURL(
+ ctx context.Context, localpart string, avatarURL string,
+) error {
+ return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
+}
+
+// SetDisplayName updates the display name of the profile associated with the given
+// localpart. Returns an error if something went wrong with the SQL query
+func (d *Database) SetDisplayName(
+ ctx context.Context, localpart string, displayName string,
+) error {
+ return d.profiles.setDisplayName(ctx, localpart, displayName)
+}
+
+// CreateAccount makes a new account with the given login name and password, and creates an empty profile
+// for this account. If no password is supplied, the account will be a passwordless account. If the
+// account already exists, it will return nil, nil.
+func (d *Database) CreateAccount(
+ ctx context.Context, localpart, plaintextPassword, appserviceID string,
+) (*authtypes.Account, error) {
+ var err error
+
+ // Generate a password hash if this is not a password-less user
+ hash := ""
+ if plaintextPassword != "" {
+ hash, err = hashPassword(plaintextPassword)
+ if err != nil {
+ return nil, err
+ }
+ }
+ if err := d.profiles.insertProfile(ctx, localpart); err != nil {
+ if common.IsUniqueConstraintViolationErr(err) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
+ "global": {
+ "content": [],
+ "override": [],
+ "room": [],
+ "sender": [],
+ "underride": []
+ }
+ }`); err != nil {
+ return nil, err
+ }
+ return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
+}
+
+// SaveMembership saves the user matching a given localpart as a member of a given
+// room. It also stores the ID of the membership event.
+// If a membership already exists between the user and the room, or if the
+// insert fails, returns the SQL error
+func (d *Database) saveMembership(
+ ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
+) error {
+ return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
+}
+
+// removeMembershipsByEventIDs removes the memberships corresponding to the
+// `join` membership events IDs in the eventIDs slice.
+// If the removal fails, or if there is no membership to remove, returns an error
+func (d *Database) removeMembershipsByEventIDs(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) error {
+ return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
+}
+
+// UpdateMemberships adds the "join" membership events included in a given state
+// events array, and removes those which ID is included in a given array of events
+// IDs. All of the process is run in a transaction, which commits only once/if every
+// insertion and deletion has been successfully processed.
+// Returns a SQL error if there was an issue with any part of the process
+func (d *Database) UpdateMemberships(
+ ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
+ return err
+ }
+
+ for _, event := range eventsToAdd {
+ if err := d.newMembership(ctx, txn, event); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+}
+
+// GetMembershipInRoomByLocalpart returns the membership for an user
+// matching the given localpart if he is a member of the room matching roomID,
+// if not sql.ErrNoRows is returned.
+// If there was an issue during the retrieval, returns the SQL error
+func (d *Database) GetMembershipInRoomByLocalpart(
+ ctx context.Context, localpart, roomID string,
+) (authtypes.Membership, error) {
+ return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
+}
+
+// GetMembershipsByLocalpart returns an array containing the memberships for all
+// the rooms a user matching a given localpart is a member of
+// If no membership match the given localpart, returns an empty array
+// If there was an issue during the retrieval, returns the SQL error
+func (d *Database) GetMembershipsByLocalpart(
+ ctx context.Context, localpart string,
+) (memberships []authtypes.Membership, err error) {
+ return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
+}
+
+// newMembership saves a new membership in the database.
+// If the event isn't a valid m.room.member event with type `join`, does nothing.
+// If an error occurred, returns the SQL error
+func (d *Database) newMembership(
+ ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
+) error {
+ if ev.Type() == "m.room.member" && ev.StateKey() != nil {
+ localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
+ if err != nil {
+ return err
+ }
+
+ // We only want state events from local users
+ if string(serverName) != string(d.serverName) {
+ return nil
+ }
+
+ eventID := ev.EventID()
+ roomID := ev.RoomID()
+ membership, err := ev.Membership()
+ if err != nil {
+ return err
+ }
+
+ // Only "join" membership events can be considered as new memberships
+ if membership == gomatrixserverlib.Join {
+ if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// SaveAccountData saves new account data for a given user and a given room.
+// If the account data is not specific to a room, the room ID should be an empty string
+// If an account data already exists for a given set (user, room, data type), it will
+// update the corresponding row with the new content
+// Returns a SQL error if there was an issue with the insertion/update
+func (d *Database) SaveAccountData(
+ ctx context.Context, localpart, roomID, dataType, content string,
+) error {
+ return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
+}
+
+// GetAccountData returns account data related to a given localpart
+// If no account data could be found, returns an empty arrays
+// Returns an error if there was an issue with the retrieval
+func (d *Database) GetAccountData(ctx context.Context, localpart string) (
+ global []gomatrixserverlib.ClientEvent,
+ rooms map[string][]gomatrixserverlib.ClientEvent,
+ err error,
+) {
+ return d.accountDatas.selectAccountData(ctx, localpart)
+}
+
+// GetAccountDataByType returns account data matching a given
+// localpart, room ID and type.
+// If no account data could be found, returns nil
+// Returns an error if there was an issue with the retrieval
+func (d *Database) GetAccountDataByType(
+ ctx context.Context, localpart, roomID, dataType string,
+) (data *gomatrixserverlib.ClientEvent, err error) {
+ return d.accountDatas.selectAccountDataByType(
+ ctx, localpart, roomID, dataType,
+ )
+}
+
+// GetNewNumericLocalpart generates and returns a new unused numeric localpart
+func (d *Database) GetNewNumericLocalpart(
+ ctx context.Context,
+) (int64, error) {
+ return d.accounts.selectNewNumericLocalpart(ctx)
+}
+
+func hashPassword(plaintext string) (hash string, err error) {
+ hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
+ return string(hashBytes), err
+}
+
+// Err3PIDInUse is the error returned when trying to save an association involving
+// a third-party identifier which is already associated to a local user.
+var Err3PIDInUse = errors.New("This third-party identifier is already in use")
+
+// SaveThreePIDAssociation saves the association between a third party identifier
+// and a local Matrix user (identified by the user's ID's local part).
+// If the third-party identifier is already part of an association, returns Err3PIDInUse.
+// Returns an error if there was a problem talking to the database.
+func (d *Database) SaveThreePIDAssociation(
+ ctx context.Context, threepid, localpart, medium string,
+) (err error) {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ user, err := d.threepids.selectLocalpartForThreePID(
+ ctx, txn, threepid, medium,
+ )
+ if err != nil {
+ return err
+ }
+
+ if len(user) > 0 {
+ return Err3PIDInUse
+ }
+
+ return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
+ })
+}
+
+// RemoveThreePIDAssociation removes the association involving a given third-party
+// identifier.
+// If no association exists involving this third-party identifier, returns nothing.
+// If there was a problem talking to the database, returns an error.
+func (d *Database) RemoveThreePIDAssociation(
+ ctx context.Context, threepid string, medium string,
+) (err error) {
+ return d.threepids.deleteThreePID(ctx, threepid, medium)
+}
+
+// GetLocalpartForThreePID looks up the localpart associated with a given third-party
+// identifier.
+// If no association involves the given third-party idenfitier, returns an empty
+// string.
+// Returns an error if there was a problem talking to the database.
+func (d *Database) GetLocalpartForThreePID(
+ ctx context.Context, threepid string, medium string,
+) (localpart string, err error) {
+ return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
+}
+
+// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
+// a given local user.
+// If no association is known for this user, returns an empty slice.
+// Returns an error if there was an issue talking to the database.
+func (d *Database) GetThreePIDsForLocalpart(
+ ctx context.Context, localpart string,
+) (threepids []authtypes.ThreePID, err error) {
+ return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
+}
+
+// GetFilter looks up the filter associated with a given local user and filter ID.
+// Returns a filter structure. Otherwise returns an error if no such filter exists
+// or if there was an error talking to the database.
+func (d *Database) GetFilter(
+ ctx context.Context, localpart string, filterID string,
+) (*gomatrixserverlib.Filter, error) {
+ return d.filter.selectFilter(ctx, localpart, filterID)
+}
+
+// PutFilter puts the passed filter into the database.
+// Returns the filterID as a string. Otherwise returns an error if something
+// goes wrong.
+func (d *Database) PutFilter(
+ ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
+) (string, error) {
+ return d.filter.insertFilter(ctx, filter, localpart)
+}
+
+// CheckAccountAvailability checks if the username/localpart is already present
+// in the database.
+// If the DB returns sql.ErrNoRows the Localpart isn't taken.
+func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
+ _, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
+ if err == sql.ErrNoRows {
+ return true, nil
+ }
+ return false, err
+}
+
+// GetAccountByLocalpart returns the account associated with the given localpart.
+// This function assumes the request is authenticated or the account data is used only internally.
+// Returns sql.ErrNoRows if no account exists which matches the given localpart.
+func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
+) (*authtypes.Account, error) {
+ return d.accounts.selectAccountByLocalpart(ctx, localpart)
+}
diff --git a/clientapi/auth/storage/accounts/threepid_table.go b/clientapi/auth/storage/accounts/sqlite3/threepid_table.go
index a03aa4f8..53f6408d 100644
--- a/clientapi/auth/storage/accounts/threepid_table.go
+++ b/clientapi/auth/storage/accounts/sqlite3/threepid_table.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package sqlite3
import (
"context"
diff --git a/clientapi/auth/storage/accounts/storage.go b/clientapi/auth/storage/accounts/storage.go
index 7cfc63c0..1dfd5f1f 100644
--- a/clientapi/auth/storage/accounts/storage.go
+++ b/clientapi/auth/storage/accounts/storage.go
@@ -1,392 +1,56 @@
-// Copyright 2017 Vector Creations Ltd
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
package accounts
import (
"context"
- "database/sql"
"errors"
+ "net/url"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/postgres"
+ "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/sqlite3"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
- "golang.org/x/crypto/bcrypt"
-
- // Import the postgres database driver.
- _ "github.com/lib/pq"
)
-// Database represents an account database
-type Database struct {
- db *sql.DB
- common.PartitionOffsetStatements
- accounts accountsStatements
- profiles profilesStatements
- memberships membershipStatements
- accountDatas accountDataStatements
- threepids threepidStatements
- filter filterStatements
- serverName gomatrixserverlib.ServerName
-}
-
-// NewDatabase creates a new accounts and profiles database
-func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
- var db *sql.DB
- var err error
- if db, err = sql.Open("postgres", dataSourceName); err != nil {
- return nil, err
- }
- partitions := common.PartitionOffsetStatements{}
- if err = partitions.Prepare(db, "account"); err != nil {
- return nil, err
- }
- a := accountsStatements{}
- if err = a.prepare(db, serverName); err != nil {
- return nil, err
- }
- p := profilesStatements{}
- if err = p.prepare(db); err != nil {
- return nil, err
- }
- m := membershipStatements{}
- if err = m.prepare(db); err != nil {
- return nil, err
- }
- ac := accountDataStatements{}
- if err = ac.prepare(db); err != nil {
- return nil, err
- }
- t := threepidStatements{}
- if err = t.prepare(db); err != nil {
- return nil, err
- }
- f := filterStatements{}
- if err = f.prepare(db); err != nil {
- return nil, err
- }
- return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
-}
-
-// GetAccountByPassword returns the account associated with the given localpart and password.
-// Returns sql.ErrNoRows if no account exists which matches the given localpart.
-func (d *Database) GetAccountByPassword(
- ctx context.Context, localpart, plaintextPassword string,
-) (*authtypes.Account, error) {
- hash, err := d.accounts.selectPasswordHash(ctx, localpart)
+type Database interface {
+ common.PartitionStorer
+ GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error)
+ GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
+ SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
+ SetDisplayName(ctx context.Context, localpart string, displayName string) error
+ CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
+ UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
+ GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
+ GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
+ SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
+ GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
+ GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
+ GetNewNumericLocalpart(ctx context.Context) (int64, error)
+ SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
+ RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
+ GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
+ GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
+ GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
+ PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
+ CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
+ GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error)
+}
+
+func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
+ uri, err := url.Parse(dataSourceName)
if err != nil {
- return nil, err
+ return postgres.NewDatabase(dataSourceName, serverName)
}
- if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
- return nil, err
+ switch uri.Scheme {
+ case "postgres":
+ return postgres.NewDatabase(dataSourceName, serverName)
+ case "file":
+ return sqlite3.NewDatabase(dataSourceName, serverName)
+ default:
+ return postgres.NewDatabase(dataSourceName, serverName)
}
- return d.accounts.selectAccountByLocalpart(ctx, localpart)
-}
-
-// GetProfileByLocalpart returns the profile associated with the given localpart.
-// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
-func (d *Database) GetProfileByLocalpart(
- ctx context.Context, localpart string,
-) (*authtypes.Profile, error) {
- return d.profiles.selectProfileByLocalpart(ctx, localpart)
-}
-
-// SetAvatarURL updates the avatar URL of the profile associated with the given
-// localpart. Returns an error if something went wrong with the SQL query
-func (d *Database) SetAvatarURL(
- ctx context.Context, localpart string, avatarURL string,
-) error {
- return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
-}
-
-// SetDisplayName updates the display name of the profile associated with the given
-// localpart. Returns an error if something went wrong with the SQL query
-func (d *Database) SetDisplayName(
- ctx context.Context, localpart string, displayName string,
-) error {
- return d.profiles.setDisplayName(ctx, localpart, displayName)
-}
-
-// CreateAccount makes a new account with the given login name and password, and creates an empty profile
-// for this account. If no password is supplied, the account will be a passwordless account. If the
-// account already exists, it will return nil, nil.
-func (d *Database) CreateAccount(
- ctx context.Context, localpart, plaintextPassword, appserviceID string,
-) (*authtypes.Account, error) {
- var err error
-
- // Generate a password hash if this is not a password-less user
- hash := ""
- if plaintextPassword != "" {
- hash, err = hashPassword(plaintextPassword)
- if err != nil {
- return nil, err
- }
- }
- if err := d.profiles.insertProfile(ctx, localpart); err != nil {
- if common.IsUniqueConstraintViolationErr(err) {
- return nil, nil
- }
- return nil, err
- }
- if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
- "global": {
- "content": [],
- "override": [],
- "room": [],
- "sender": [],
- "underride": []
- }
- }`); err != nil {
- return nil, err
- }
- return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
-}
-
-// SaveMembership saves the user matching a given localpart as a member of a given
-// room. It also stores the ID of the membership event.
-// If a membership already exists between the user and the room, or if the
-// insert fails, returns the SQL error
-func (d *Database) saveMembership(
- ctx context.Context, txn *sql.Tx, localpart, roomID, eventID string,
-) error {
- return d.memberships.insertMembership(ctx, txn, localpart, roomID, eventID)
-}
-
-// removeMembershipsByEventIDs removes the memberships corresponding to the
-// `join` membership events IDs in the eventIDs slice.
-// If the removal fails, or if there is no membership to remove, returns an error
-func (d *Database) removeMembershipsByEventIDs(
- ctx context.Context, txn *sql.Tx, eventIDs []string,
-) error {
- return d.memberships.deleteMembershipsByEventIDs(ctx, txn, eventIDs)
-}
-
-// UpdateMemberships adds the "join" membership events included in a given state
-// events array, and removes those which ID is included in a given array of events
-// IDs. All of the process is run in a transaction, which commits only once/if every
-// insertion and deletion has been successfully processed.
-// Returns a SQL error if there was an issue with any part of the process
-func (d *Database) UpdateMemberships(
- ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string,
-) error {
- return common.WithTransaction(d.db, func(txn *sql.Tx) error {
- if err := d.removeMembershipsByEventIDs(ctx, txn, idsToRemove); err != nil {
- return err
- }
-
- for _, event := range eventsToAdd {
- if err := d.newMembership(ctx, txn, event); err != nil {
- return err
- }
- }
-
- return nil
- })
-}
-
-// GetMembershipInRoomByLocalpart returns the membership for an user
-// matching the given localpart if he is a member of the room matching roomID,
-// if not sql.ErrNoRows is returned.
-// If there was an issue during the retrieval, returns the SQL error
-func (d *Database) GetMembershipInRoomByLocalpart(
- ctx context.Context, localpart, roomID string,
-) (authtypes.Membership, error) {
- return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID)
-}
-
-// GetMembershipsByLocalpart returns an array containing the memberships for all
-// the rooms a user matching a given localpart is a member of
-// If no membership match the given localpart, returns an empty array
-// If there was an issue during the retrieval, returns the SQL error
-func (d *Database) GetMembershipsByLocalpart(
- ctx context.Context, localpart string,
-) (memberships []authtypes.Membership, err error) {
- return d.memberships.selectMembershipsByLocalpart(ctx, localpart)
-}
-
-// newMembership saves a new membership in the database.
-// If the event isn't a valid m.room.member event with type `join`, does nothing.
-// If an error occurred, returns the SQL error
-func (d *Database) newMembership(
- ctx context.Context, txn *sql.Tx, ev gomatrixserverlib.Event,
-) error {
- if ev.Type() == "m.room.member" && ev.StateKey() != nil {
- localpart, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
- if err != nil {
- return err
- }
-
- // We only want state events from local users
- if string(serverName) != string(d.serverName) {
- return nil
- }
-
- eventID := ev.EventID()
- roomID := ev.RoomID()
- membership, err := ev.Membership()
- if err != nil {
- return err
- }
-
- // Only "join" membership events can be considered as new memberships
- if membership == gomatrixserverlib.Join {
- if err := d.saveMembership(ctx, txn, localpart, roomID, eventID); err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-// SaveAccountData saves new account data for a given user and a given room.
-// If the account data is not specific to a room, the room ID should be an empty string
-// If an account data already exists for a given set (user, room, data type), it will
-// update the corresponding row with the new content
-// Returns a SQL error if there was an issue with the insertion/update
-func (d *Database) SaveAccountData(
- ctx context.Context, localpart, roomID, dataType, content string,
-) error {
- return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
-}
-
-// GetAccountData returns account data related to a given localpart
-// If no account data could be found, returns an empty arrays
-// Returns an error if there was an issue with the retrieval
-func (d *Database) GetAccountData(ctx context.Context, localpart string) (
- global []gomatrixserverlib.ClientEvent,
- rooms map[string][]gomatrixserverlib.ClientEvent,
- err error,
-) {
- return d.accountDatas.selectAccountData(ctx, localpart)
-}
-
-// GetAccountDataByType returns account data matching a given
-// localpart, room ID and type.
-// If no account data could be found, returns nil
-// Returns an error if there was an issue with the retrieval
-func (d *Database) GetAccountDataByType(
- ctx context.Context, localpart, roomID, dataType string,
-) (data *gomatrixserverlib.ClientEvent, err error) {
- return d.accountDatas.selectAccountDataByType(
- ctx, localpart, roomID, dataType,
- )
-}
-
-// GetNewNumericLocalpart generates and returns a new unused numeric localpart
-func (d *Database) GetNewNumericLocalpart(
- ctx context.Context,
-) (int64, error) {
- return d.accounts.selectNewNumericLocalpart(ctx)
-}
-
-func hashPassword(plaintext string) (hash string, err error) {
- hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost)
- return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("This third-party identifier is already in use")
-
-// SaveThreePIDAssociation saves the association between a third party identifier
-// and a local Matrix user (identified by the user's ID's local part).
-// If the third-party identifier is already part of an association, returns Err3PIDInUse.
-// Returns an error if there was a problem talking to the database.
-func (d *Database) SaveThreePIDAssociation(
- ctx context.Context, threepid, localpart, medium string,
-) (err error) {
- return common.WithTransaction(d.db, func(txn *sql.Tx) error {
- user, err := d.threepids.selectLocalpartForThreePID(
- ctx, txn, threepid, medium,
- )
- if err != nil {
- return err
- }
-
- if len(user) > 0 {
- return Err3PIDInUse
- }
-
- return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
- })
-}
-
-// RemoveThreePIDAssociation removes the association involving a given third-party
-// identifier.
-// If no association exists involving this third-party identifier, returns nothing.
-// If there was a problem talking to the database, returns an error.
-func (d *Database) RemoveThreePIDAssociation(
- ctx context.Context, threepid string, medium string,
-) (err error) {
- return d.threepids.deleteThreePID(ctx, threepid, medium)
-}
-
-// GetLocalpartForThreePID looks up the localpart associated with a given third-party
-// identifier.
-// If no association involves the given third-party idenfitier, returns an empty
-// string.
-// Returns an error if there was a problem talking to the database.
-func (d *Database) GetLocalpartForThreePID(
- ctx context.Context, threepid string, medium string,
-) (localpart string, err error) {
- return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
-}
-
-// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
-// a given local user.
-// If no association is known for this user, returns an empty slice.
-// Returns an error if there was an issue talking to the database.
-func (d *Database) GetThreePIDsForLocalpart(
- ctx context.Context, localpart string,
-) (threepids []authtypes.ThreePID, err error) {
- return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
-}
-
-// GetFilter looks up the filter associated with a given local user and filter ID.
-// Returns a filter structure. Otherwise returns an error if no such filter exists
-// or if there was an error talking to the database.
-func (d *Database) GetFilter(
- ctx context.Context, localpart string, filterID string,
-) (*gomatrixserverlib.Filter, error) {
- return d.filter.selectFilter(ctx, localpart, filterID)
-}
-
-// PutFilter puts the passed filter into the database.
-// Returns the filterID as a string. Otherwise returns an error if something
-// goes wrong.
-func (d *Database) PutFilter(
- ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
-) (string, error) {
- return d.filter.insertFilter(ctx, filter, localpart)
-}
-
-// CheckAccountAvailability checks if the username/localpart is already present
-// in the database.
-// If the DB returns sql.ErrNoRows the Localpart isn't taken.
-func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
- _, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
- if err == sql.ErrNoRows {
- return true, nil
- }
- return false, err
-}
-
-// GetAccountByLocalpart returns the account associated with the given localpart.
-// This function assumes the request is authenticated or the account data is used only internally.
-// Returns sql.ErrNoRows if no account exists which matches the given localpart.
-func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
-) (*authtypes.Account, error) {
- return d.accounts.selectAccountByLocalpart(ctx, localpart)
-}
diff --git a/clientapi/auth/storage/devices/devices_table.go b/clientapi/auth/storage/devices/postgres/devices_table.go
index 99741247..349bf1ef 100644
--- a/clientapi/auth/storage/devices/devices_table.go
+++ b/clientapi/auth/storage/devices/postgres/devices_table.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package devices
+package postgres
import (
"context"
diff --git a/clientapi/auth/storage/devices/postgres/storage.go b/clientapi/auth/storage/devices/postgres/storage.go
new file mode 100644
index 00000000..221c3998
--- /dev/null
+++ b/clientapi/auth/storage/devices/postgres/storage.go
@@ -0,0 +1,182 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "crypto/rand"
+ "database/sql"
+ "encoding/base64"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// The length of generated device IDs
+var deviceIDByteLength = 6
+
+// Database represents a device database.
+type Database struct {
+ db *sql.DB
+ devices devicesStatements
+}
+
+// NewDatabase creates a new device database
+func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
+ var db *sql.DB
+ var err error
+ if db, err = sql.Open("postgres", dataSourceName); err != nil {
+ return nil, err
+ }
+ d := devicesStatements{}
+ if err = d.prepare(db, serverName); err != nil {
+ return nil, err
+ }
+ return &Database{db, d}, nil
+}
+
+// GetDeviceByAccessToken returns the device matching the given access token.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByAccessToken(
+ ctx context.Context, token string,
+) (*authtypes.Device, error) {
+ return d.devices.selectDeviceByToken(ctx, token)
+}
+
+// GetDeviceByID returns the device matching the given ID.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByID(
+ ctx context.Context, localpart, deviceID string,
+) (*authtypes.Device, error) {
+ return d.devices.selectDeviceByID(ctx, localpart, deviceID)
+}
+
+// GetDevicesByLocalpart returns the devices matching the given localpart.
+func (d *Database) GetDevicesByLocalpart(
+ ctx context.Context, localpart string,
+) ([]authtypes.Device, error) {
+ return d.devices.selectDevicesByLocalpart(ctx, localpart)
+}
+
+// CreateDevice makes a new device associated with the given user ID localpart.
+// If there is already a device with the same device ID for this user, that access token will be revoked
+// and replaced with the given accessToken. If the given accessToken is already in use for another device,
+// an error will be returned.
+// If no device ID is given one is generated.
+// Returns the device on success.
+func (d *Database) CreateDevice(
+ ctx context.Context, localpart string, deviceID *string, accessToken string,
+ displayName *string,
+) (dev *authtypes.Device, returnErr error) {
+ if deviceID != nil {
+ returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ // Revoke existing tokens for this device
+ if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
+ return err
+ }
+
+ dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
+ return err
+ })
+ } else {
+ // We generate device IDs in a loop in case its already taken.
+ // We cap this at going round 5 times to ensure we don't spin forever
+ var newDeviceID string
+ for i := 1; i <= 5; i++ {
+ newDeviceID, returnErr = generateDeviceID()
+ if returnErr != nil {
+ return
+ }
+
+ returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
+ return err
+ })
+ if returnErr == nil {
+ return
+ }
+ }
+ }
+ return
+}
+
+// generateDeviceID creates a new device id. Returns an error if failed to generate
+// random bytes.
+func generateDeviceID() (string, error) {
+ b := make([]byte, deviceIDByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ // url-safe no padding
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// UpdateDevice updates the given device with the display name.
+// Returns SQL error if there are problems and nil on success.
+func (d *Database) UpdateDevice(
+ ctx context.Context, localpart, deviceID string, displayName *string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
+ })
+}
+
+// RemoveDevice revokes a device by deleting the entry in the database
+// matching with the given device ID and user ID localpart.
+// If the device doesn't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevice(
+ ctx context.Context, deviceID, localpart string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveDevices revokes one or more devices by deleting the entry in the database
+// matching with the given device IDs and user ID localpart.
+// If the devices don't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevices(
+ ctx context.Context, localpart string, devices []string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveAllDevices revokes devices by deleting the entry in the
+// database matching the given user ID localpart.
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveAllDevices(
+ ctx context.Context, localpart string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
diff --git a/clientapi/auth/storage/devices/sqlite3/devices_table.go b/clientapi/auth/storage/devices/sqlite3/devices_table.go
new file mode 100644
index 00000000..dc88890d
--- /dev/null
+++ b/clientapi/auth/storage/devices/sqlite3/devices_table.go
@@ -0,0 +1,243 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+ "time"
+
+ "github.com/matrix-org/dendrite/common"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/clientapi/userutil"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const devicesSchema = `
+-- This sequence is used for automatic allocation of session_id.
+-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
+
+-- Stores data about devices.
+CREATE TABLE IF NOT EXISTS device_devices (
+ access_token TEXT PRIMARY KEY,
+ session_id INTEGER,
+ device_id TEXT ,
+ localpart TEXT ,
+ created_ts BIGINT,
+ display_name TEXT,
+
+ UNIQUE (localpart, device_id)
+);
+`
+
+const insertDeviceSQL = "" +
+ "INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)"
+
+const selectDevicesCountSQL = "" +
+ "SELECT COUNT(access_token) FROM device_devices"
+
+const selectDeviceByTokenSQL = "" +
+ "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
+
+const selectDeviceByIDSQL = "" +
+ "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
+
+const selectDevicesByLocalpartSQL = "" +
+ "SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
+
+const updateDeviceNameSQL = "" +
+ "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
+
+const deleteDeviceSQL = "" +
+ "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
+
+const deleteDevicesByLocalpartSQL = "" +
+ "DELETE FROM device_devices WHERE localpart = $1"
+
+const deleteDevicesSQL = "" +
+ "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
+
+type devicesStatements struct {
+ db *sql.DB
+ insertDeviceStmt *sql.Stmt
+ selectDevicesCountStmt *sql.Stmt
+ selectDeviceByTokenStmt *sql.Stmt
+ selectDeviceByIDStmt *sql.Stmt
+ selectDevicesByLocalpartStmt *sql.Stmt
+ updateDeviceNameStmt *sql.Stmt
+ deleteDeviceStmt *sql.Stmt
+ deleteDevicesByLocalpartStmt *sql.Stmt
+ serverName gomatrixserverlib.ServerName
+}
+
+func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+ s.db = db
+ _, err = db.Exec(devicesSchema)
+ if err != nil {
+ return
+ }
+ if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
+ return
+ }
+ if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
+ return
+ }
+ if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
+ return
+ }
+ if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
+ return
+ }
+ if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
+ return
+ }
+ if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
+ return
+ }
+ if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
+ return
+ }
+ if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
+ return
+ }
+ s.serverName = server
+ return
+}
+
+// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
+// Returns an error if the user already has a device with the given device ID.
+// Returns the device on success.
+func (s *devicesStatements) insertDevice(
+ ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
+ displayName *string,
+) (*authtypes.Device, error) {
+ createdTimeMS := time.Now().UnixNano() / 1000000
+ var sessionID int64
+ countStmt := common.TxStmt(txn, s.selectDevicesCountStmt)
+ insertStmt := common.TxStmt(txn, s.insertDeviceStmt)
+ if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
+ return nil, err
+ }
+ sessionID++
+ if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
+ return nil, err
+ }
+ return &authtypes.Device{
+ ID: id,
+ UserID: userutil.MakeUserID(localpart, s.serverName),
+ AccessToken: accessToken,
+ SessionID: sessionID,
+ }, nil
+}
+
+func (s *devicesStatements) deleteDevice(
+ ctx context.Context, txn *sql.Tx, id, localpart string,
+) error {
+ stmt := common.TxStmt(txn, s.deleteDeviceStmt)
+ _, err := stmt.ExecContext(ctx, id, localpart)
+ return err
+}
+
+func (s *devicesStatements) deleteDevices(
+ ctx context.Context, txn *sql.Tx, localpart string, devices []string,
+) error {
+ orig := strings.Replace(deleteDevicesSQL, "($1)", common.QueryVariadic(len(devices)), 1)
+ prep, err := s.db.Prepare(orig)
+ if err != nil {
+ return err
+ }
+ stmt := common.TxStmt(txn, prep)
+ params := make([]interface{}, len(devices)+1)
+ params[0] = localpart
+ for i, v := range devices {
+ params[i+1] = v
+ }
+ params = append(params, params...)
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
+}
+
+func (s *devicesStatements) deleteDevicesByLocalpart(
+ ctx context.Context, txn *sql.Tx, localpart string,
+) error {
+ stmt := common.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
+ _, err := stmt.ExecContext(ctx, localpart)
+ return err
+}
+
+func (s *devicesStatements) updateDeviceName(
+ ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
+) error {
+ stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
+ _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
+ return err
+}
+
+func (s *devicesStatements) selectDeviceByToken(
+ ctx context.Context, accessToken string,
+) (*authtypes.Device, error) {
+ var dev authtypes.Device
+ var localpart string
+ stmt := s.selectDeviceByTokenStmt
+ err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
+ if err == nil {
+ dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ dev.AccessToken = accessToken
+ }
+ return &dev, err
+}
+
+// selectDeviceByID retrieves a device from the database with the given user
+// localpart and deviceID
+func (s *devicesStatements) selectDeviceByID(
+ ctx context.Context, localpart, deviceID string,
+) (*authtypes.Device, error) {
+ var dev authtypes.Device
+ var created sql.NullInt64
+ stmt := s.selectDeviceByIDStmt
+ err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&created)
+ if err == nil {
+ dev.ID = deviceID
+ dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ }
+ return &dev, err
+}
+
+func (s *devicesStatements) selectDevicesByLocalpart(
+ ctx context.Context, localpart string,
+) ([]authtypes.Device, error) {
+ devices := []authtypes.Device{}
+
+ rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
+
+ if err != nil {
+ return devices, err
+ }
+
+ for rows.Next() {
+ var dev authtypes.Device
+ err = rows.Scan(&dev.ID)
+ if err != nil {
+ return devices, err
+ }
+ dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ devices = append(devices, dev)
+ }
+
+ return devices, nil
+}
diff --git a/clientapi/auth/storage/devices/sqlite3/storage.go b/clientapi/auth/storage/devices/sqlite3/storage.go
new file mode 100644
index 00000000..e1ce6f00
--- /dev/null
+++ b/clientapi/auth/storage/devices/sqlite3/storage.go
@@ -0,0 +1,184 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "crypto/rand"
+ "database/sql"
+ "encoding/base64"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/gomatrixserverlib"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+// The length of generated device IDs
+var deviceIDByteLength = 6
+
+// Database represents a device database.
+type Database struct {
+ db *sql.DB
+ devices devicesStatements
+}
+
+// NewDatabase creates a new device database
+func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
+ var db *sql.DB
+ var err error
+ if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
+ return nil, err
+ }
+ d := devicesStatements{}
+ if err = d.prepare(db, serverName); err != nil {
+ return nil, err
+ }
+ return &Database{db, d}, nil
+}
+
+// GetDeviceByAccessToken returns the device matching the given access token.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByAccessToken(
+ ctx context.Context, token string,
+) (*authtypes.Device, error) {
+ return d.devices.selectDeviceByToken(ctx, token)
+}
+
+// GetDeviceByID returns the device matching the given ID.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByID(
+ ctx context.Context, localpart, deviceID string,
+) (*authtypes.Device, error) {
+ return d.devices.selectDeviceByID(ctx, localpart, deviceID)
+}
+
+// GetDevicesByLocalpart returns the devices matching the given localpart.
+func (d *Database) GetDevicesByLocalpart(
+ ctx context.Context, localpart string,
+) ([]authtypes.Device, error) {
+ return d.devices.selectDevicesByLocalpart(ctx, localpart)
+}
+
+// CreateDevice makes a new device associated with the given user ID localpart.
+// If there is already a device with the same device ID for this user, that access token will be revoked
+// and replaced with the given accessToken. If the given accessToken is already in use for another device,
+// an error will be returned.
+// If no device ID is given one is generated.
+// Returns the device on success.
+func (d *Database) CreateDevice(
+ ctx context.Context, localpart string, deviceID *string, accessToken string,
+ displayName *string,
+) (dev *authtypes.Device, returnErr error) {
+ if deviceID != nil {
+ returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ // Revoke existing tokens for this device
+ if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
+ return err
+ }
+
+ dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
+ return err
+ })
+ } else {
+ // We generate device IDs in a loop in case its already taken.
+ // We cap this at going round 5 times to ensure we don't spin forever
+ var newDeviceID string
+ for i := 1; i <= 5; i++ {
+ newDeviceID, returnErr = generateDeviceID()
+ if returnErr != nil {
+ return
+ }
+
+ returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
+ return err
+ })
+ if returnErr == nil {
+ return
+ }
+ }
+ }
+ return
+}
+
+// generateDeviceID creates a new device id. Returns an error if failed to generate
+// random bytes.
+func generateDeviceID() (string, error) {
+ b := make([]byte, deviceIDByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ // url-safe no padding
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// UpdateDevice updates the given device with the display name.
+// Returns SQL error if there are problems and nil on success.
+func (d *Database) UpdateDevice(
+ ctx context.Context, localpart, deviceID string, displayName *string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
+ })
+}
+
+// RemoveDevice revokes a device by deleting the entry in the database
+// matching with the given device ID and user ID localpart.
+// If the device doesn't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevice(
+ ctx context.Context, deviceID, localpart string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveDevices revokes one or more devices by deleting the entry in the database
+// matching with the given device IDs and user ID localpart.
+// If the devices don't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevices(
+ ctx context.Context, localpart string, devices []string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveAllDevices revokes devices by deleting the entry in the
+// database matching the given user ID localpart.
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveAllDevices(
+ ctx context.Context, localpart string,
+) error {
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
diff --git a/clientapi/auth/storage/devices/storage.go b/clientapi/auth/storage/devices/storage.go
index 150180c1..82f75640 100644
--- a/clientapi/auth/storage/devices/storage.go
+++ b/clientapi/auth/storage/devices/storage.go
@@ -1,182 +1,37 @@
-// Copyright 2017 Vector Creations Ltd
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
package devices
import (
"context"
- "crypto/rand"
- "database/sql"
- "encoding/base64"
+ "net/url"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
- "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/postgres"
+ "github.com/matrix-org/dendrite/clientapi/auth/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
-// The length of generated device IDs
-var deviceIDByteLength = 6
-
-// Database represents a device database.
-type Database struct {
- db *sql.DB
- devices devicesStatements
-}
-
-// NewDatabase creates a new device database
-func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
- var db *sql.DB
- var err error
- if db, err = sql.Open("postgres", dataSourceName); err != nil {
- return nil, err
- }
- d := devicesStatements{}
- if err = d.prepare(db, serverName); err != nil {
- return nil, err
- }
- return &Database{db, d}, nil
-}
-
-// GetDeviceByAccessToken returns the device matching the given access token.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByAccessToken(
- ctx context.Context, token string,
-) (*authtypes.Device, error) {
- return d.devices.selectDeviceByToken(ctx, token)
-}
-
-// GetDeviceByID returns the device matching the given ID.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByID(
- ctx context.Context, localpart, deviceID string,
-) (*authtypes.Device, error) {
- return d.devices.selectDeviceByID(ctx, localpart, deviceID)
-}
-
-// GetDevicesByLocalpart returns the devices matching the given localpart.
-func (d *Database) GetDevicesByLocalpart(
- ctx context.Context, localpart string,
-) ([]authtypes.Device, error) {
- return d.devices.selectDevicesByLocalpart(ctx, localpart)
-}
-
-// CreateDevice makes a new device associated with the given user ID localpart.
-// If there is already a device with the same device ID for this user, that access token will be revoked
-// and replaced with the given accessToken. If the given accessToken is already in use for another device,
-// an error will be returned.
-// If no device ID is given one is generated.
-// Returns the device on success.
-func (d *Database) CreateDevice(
- ctx context.Context, localpart string, deviceID *string, accessToken string,
- displayName *string,
-) (dev *authtypes.Device, returnErr error) {
- if deviceID != nil {
- returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
- var err error
- // Revoke existing tokens for this device
- if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
- return err
- }
-
- dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
- return err
- })
- } else {
- // We generate device IDs in a loop in case its already taken.
- // We cap this at going round 5 times to ensure we don't spin forever
- var newDeviceID string
- for i := 1; i <= 5; i++ {
- newDeviceID, returnErr = generateDeviceID()
- if returnErr != nil {
- return
- }
-
- returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
- var err error
- dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
- return err
- })
- if returnErr == nil {
- return
- }
- }
- }
- return
+type Database interface {
+ GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
+ GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
+ GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
+ CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
+ UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
+ RemoveDevice(ctx context.Context, deviceID, localpart string) error
+ RemoveDevices(ctx context.Context, localpart string, devices []string) error
+ RemoveAllDevices(ctx context.Context, localpart string) error
}
-// generateDeviceID creates a new device id. Returns an error if failed to generate
-// random bytes.
-func generateDeviceID() (string, error) {
- b := make([]byte, deviceIDByteLength)
- _, err := rand.Read(b)
+func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (Database, error) {
+ uri, err := url.Parse(dataSourceName)
if err != nil {
- return "", err
+ return postgres.NewDatabase(dataSourceName, serverName)
+ }
+ switch uri.Scheme {
+ case "postgres":
+ return postgres.NewDatabase(dataSourceName, serverName)
+ case "file":
+ return sqlite3.NewDatabase(dataSourceName, serverName)
+ default:
+ return postgres.NewDatabase(dataSourceName, serverName)
}
- // url-safe no padding
- return base64.RawURLEncoding.EncodeToString(b), nil
-}
-
-// UpdateDevice updates the given device with the display name.
-// Returns SQL error if there are problems and nil on success.
-func (d *Database) UpdateDevice(
- ctx context.Context, localpart, deviceID string, displayName *string,
-) error {
- return common.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
- })
-}
-
-// RemoveDevice revokes a device by deleting the entry in the database
-// matching with the given device ID and user ID localpart.
-// If the device doesn't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevice(
- ctx context.Context, deviceID, localpart string,
-) error {
- return common.WithTransaction(d.db, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveDevices revokes one or more devices by deleting the entry in the database
-// matching with the given device IDs and user ID localpart.
-// If the devices don't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevices(
- ctx context.Context, localpart string, devices []string,
-) error {
- return common.WithTransaction(d.db, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveAllDevices revokes devices by deleting the entry in the
-// database matching the given user ID localpart.
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart string,
-) error {
- return common.WithTransaction(d.db, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
}
diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go
index c911fecc..bb44e016 100644
--- a/clientapi/clientapi.go
+++ b/clientapi/clientapi.go
@@ -34,8 +34,8 @@ import (
// component.
func SetupClientAPIComponent(
base *basecomponent.BaseDendrite,
- deviceDB *devices.Database,
- accountsDB *accounts.Database,
+ deviceDB devices.Database,
+ accountsDB accounts.Database,
federation *gomatrixserverlib.FederationClient,
keyRing *gomatrixserverlib.KeyRing,
aliasAPI roomserverAPI.RoomserverAliasAPI,
diff --git a/clientapi/consumers/roomserver.go b/clientapi/consumers/roomserver.go
index 0ee7c6bf..a6528151 100644
--- a/clientapi/consumers/roomserver.go
+++ b/clientapi/consumers/roomserver.go
@@ -31,7 +31,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer
- db *accounts.Database
+ db accounts.Database
query api.RoomserverQueryAPI
serverName string
}
@@ -40,7 +40,7 @@ type OutputRoomEventConsumer struct {
func NewOutputRoomEventConsumer(
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
- store *accounts.Database,
+ store accounts.Database,
queryAPI api.RoomserverQueryAPI,
) *OutputRoomEventConsumer {
diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go
index bbc8c258..8ae9de2d 100644
--- a/clientapi/routing/account_data.go
+++ b/clientapi/routing/account_data.go
@@ -30,7 +30,7 @@ import (
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
func GetAccountData(
- req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
+ req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, roomID string, dataType string,
) util.JSONResponse {
if userID != device.UserID {
@@ -62,7 +62,7 @@ func GetAccountData(
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
func SaveAccountData(
- req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
+ req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
) util.JSONResponse {
if userID != device.UserID {
diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go
index f6f06421..2b1245b9 100644
--- a/clientapi/routing/createroom.go
+++ b/clientapi/routing/createroom.go
@@ -135,7 +135,7 @@ type fledglingEvent struct {
func CreateRoom(
req *http.Request, device *authtypes.Device,
cfg *config.Dendrite, producer *producers.RoomserverProducer,
- accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
+ accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
// TODO (#267): Check room ID doesn't clash with an existing one, and we
@@ -149,7 +149,7 @@ func CreateRoom(
func createRoom(
req *http.Request, device *authtypes.Device,
cfg *config.Dendrite, roomID string, producer *producers.RoomserverProducer,
- accountDB *accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
+ accountDB accounts.Database, aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
logger := util.GetLogger(req.Context())
diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go
index eb7cd0b0..9b8647cd 100644
--- a/clientapi/routing/device.go
+++ b/clientapi/routing/device.go
@@ -46,7 +46,7 @@ type devicesDeleteJSON struct {
// GetDeviceByID handles /devices/{deviceID}
func GetDeviceByID(
- req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
+ req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
@@ -76,7 +76,7 @@ func GetDeviceByID(
// GetDevicesByLocalpart handles /devices
func GetDevicesByLocalpart(
- req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
+ req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
@@ -107,7 +107,7 @@ func GetDevicesByLocalpart(
// UpdateDeviceByID handles PUT on /devices/{deviceID}
func UpdateDeviceByID(
- req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
+ req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
@@ -153,7 +153,7 @@ func UpdateDeviceByID(
// DeleteDeviceById handles DELETE requests to /devices/{deviceId}
func DeleteDeviceById(
- req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
+ req *http.Request, deviceDB devices.Database, device *authtypes.Device,
deviceID string,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
@@ -176,7 +176,7 @@ func DeleteDeviceById(
// DeleteDevices handles POST requests to /delete_devices
func DeleteDevices(
- req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
+ req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
diff --git a/clientapi/routing/filter.go b/clientapi/routing/filter.go
index eec501ff..583b2395 100644
--- a/clientapi/routing/filter.go
+++ b/clientapi/routing/filter.go
@@ -27,7 +27,7 @@ import (
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
func GetFilter(
- req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string,
+ req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string, filterID string,
) util.JSONResponse {
if userID != device.UserID {
return util.JSONResponse{
@@ -63,7 +63,7 @@ type filterResponse struct {
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
func PutFilter(
- req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string,
+ req *http.Request, device *authtypes.Device, accountDB accounts.Database, userID string,
) util.JSONResponse {
if userID != device.UserID {
return util.JSONResponse{
diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go
index 8b3f3740..5e6f3e55 100644
--- a/clientapi/routing/joinroom.go
+++ b/clientapi/routing/joinroom.go
@@ -45,7 +45,7 @@ func JoinRoomByIDOrAlias(
queryAPI roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI,
keyRing gomatrixserverlib.KeyRing,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
) util.JSONResponse {
var content map[string]interface{} // must be a JSON object
if resErr := httputil.UnmarshalJSONRequest(req, &content); resErr != nil {
diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go
index 2f4fb83c..b8364ed9 100644
--- a/clientapi/routing/login.go
+++ b/clientapi/routing/login.go
@@ -70,7 +70,7 @@ func passwordLogin() loginFlows {
// Login implements GET and POST /login
func Login(
- req *http.Request, accountDB *accounts.Database, deviceDB *devices.Database,
+ req *http.Request, accountDB accounts.Database, deviceDB devices.Database,
cfg *config.Dendrite,
) util.JSONResponse {
if req.Method == http.MethodGet { // TODO: support other forms of login other than password, depending on config options
@@ -153,7 +153,7 @@ func Login(
func getDevice(
ctx context.Context,
r passwordRequest,
- deviceDB *devices.Database,
+ deviceDB devices.Database,
acc *authtypes.Account,
token string,
) (dev *authtypes.Device, err error) {
diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go
index 3294fbcd..0ac9ca4a 100644
--- a/clientapi/routing/logout.go
+++ b/clientapi/routing/logout.go
@@ -26,7 +26,7 @@ import (
// Logout handles POST /logout
func Logout(
- req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
+ req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
@@ -45,7 +45,7 @@ func Logout(
// LogoutAll handles POST /logout/all
func LogoutAll(
- req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
+ req *http.Request, deviceDB devices.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go
index 8b8b3a0f..68c131a2 100644
--- a/clientapi/routing/membership.go
+++ b/clientapi/routing/membership.go
@@ -40,7 +40,7 @@ var errMissingUserID = errors.New("'user_id' must be supplied")
// SendMembership implements PUT /rooms/{roomID}/(join|kick|ban|unban|leave|invite)
// by building a m.room.member event then sending it to the room server
func SendMembership(
- req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
+ req *http.Request, accountDB accounts.Database, device *authtypes.Device,
roomID string, membership string, cfg *config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
producer *producers.RoomserverProducer,
@@ -116,7 +116,7 @@ func SendMembership(
func buildMembershipEvent(
ctx context.Context,
- body threepid.MembershipRequest, accountDB *accounts.Database,
+ body threepid.MembershipRequest, accountDB accounts.Database,
device *authtypes.Device,
membership, roomID string,
cfg *config.Dendrite, evTime time.Time,
@@ -166,7 +166,7 @@ func loadProfile(
ctx context.Context,
userID string,
cfg *config.Dendrite,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
asAPI appserviceAPI.AppServiceQueryAPI,
) (*authtypes.Profile, error) {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
@@ -216,7 +216,7 @@ func checkAndProcessThreepid(
body *threepid.MembershipRequest,
cfg *config.Dendrite,
queryAPI roomserverAPI.RoomserverQueryAPI,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
producer *producers.RoomserverProducer,
membership, roomID string,
evTime time.Time,
diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go
index 4688b19e..9b091ddf 100644
--- a/clientapi/routing/profile.go
+++ b/clientapi/routing/profile.go
@@ -36,7 +36,7 @@ import (
// GetProfile implements GET /profile/{userID}
func GetProfile(
- req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
+ req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
@@ -64,7 +64,7 @@ func GetProfile(
// GetAvatarURL implements GET /profile/{userID}/avatar_url
func GetAvatarURL(
- req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
+ req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@@ -90,7 +90,7 @@ func GetAvatarURL(
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
func SetAvatarURL(
- req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
+ req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse {
@@ -170,7 +170,7 @@ func SetAvatarURL(
// GetDisplayName implements GET /profile/{userID}/displayname
func GetDisplayName(
- req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite,
+ req *http.Request, accountDB accounts.Database, cfg *config.Dendrite,
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
) util.JSONResponse {
@@ -196,7 +196,7 @@ func GetDisplayName(
// SetDisplayName implements PUT /profile/{userID}/displayname
func SetDisplayName(
- req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
+ req *http.Request, accountDB accounts.Database, device *authtypes.Device,
userID string, producer *producers.UserUpdateProducer, cfg *config.Dendrite,
rsProducer *producers.RoomserverProducer, queryAPI api.RoomserverQueryAPI,
) util.JSONResponse {
@@ -279,7 +279,7 @@ func SetDisplayName(
// Returns an error when something goes wrong or specifically
// common.ErrProfileNoExists when the profile doesn't exist.
func getProfile(
- ctx context.Context, accountDB *accounts.Database, cfg *config.Dendrite,
+ ctx context.Context, accountDB accounts.Database, cfg *config.Dendrite,
userID string,
asAPI appserviceAPI.AppServiceQueryAPI,
federation *gomatrixserverlib.FederationClient,
diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go
index 4375faaf..9d67d998 100644
--- a/clientapi/routing/register.go
+++ b/clientapi/routing/register.go
@@ -440,8 +440,8 @@ func validateApplicationService(
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
func Register(
req *http.Request,
- accountDB *accounts.Database,
- deviceDB *devices.Database,
+ accountDB accounts.Database,
+ deviceDB devices.Database,
cfg *config.Dendrite,
) util.JSONResponse {
@@ -513,8 +513,8 @@ func handleGuestRegistration(
req *http.Request,
r registerRequest,
cfg *config.Dendrite,
- accountDB *accounts.Database,
- deviceDB *devices.Database,
+ accountDB accounts.Database,
+ deviceDB devices.Database,
) util.JSONResponse {
//Generate numeric local part for guest user
@@ -570,8 +570,8 @@ func handleRegistrationFlow(
r registerRequest,
sessionID string,
cfg *config.Dendrite,
- accountDB *accounts.Database,
- deviceDB *devices.Database,
+ accountDB accounts.Database,
+ deviceDB devices.Database,
) util.JSONResponse {
// TODO: Shared secret registration (create new user scripts)
// TODO: Enable registration config flag
@@ -668,8 +668,8 @@ func handleApplicationServiceRegistration(
req *http.Request,
r registerRequest,
cfg *config.Dendrite,
- accountDB *accounts.Database,
- deviceDB *devices.Database,
+ accountDB accounts.Database,
+ deviceDB devices.Database,
) util.JSONResponse {
// Check if we previously had issues extracting the access token from the
// request.
@@ -707,8 +707,8 @@ func checkAndCompleteFlow(
r registerRequest,
sessionID string,
cfg *config.Dendrite,
- accountDB *accounts.Database,
- deviceDB *devices.Database,
+ accountDB accounts.Database,
+ deviceDB devices.Database,
) util.JSONResponse {
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue
@@ -730,8 +730,8 @@ func checkAndCompleteFlow(
// LegacyRegister process register requests from the legacy v1 API
func LegacyRegister(
req *http.Request,
- accountDB *accounts.Database,
- deviceDB *devices.Database,
+ accountDB accounts.Database,
+ deviceDB devices.Database,
cfg *config.Dendrite,
) util.JSONResponse {
var r legacyRegisterRequest
@@ -814,8 +814,8 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
// not all
func completeRegistration(
ctx context.Context,
- accountDB *accounts.Database,
- deviceDB *devices.Database,
+ accountDB accounts.Database,
+ deviceDB devices.Database,
username, password, appserviceID string,
inhibitLogin common.WeakBoolean,
displayName, deviceID *string,
@@ -992,7 +992,7 @@ type availableResponse struct {
func RegisterAvailable(
req *http.Request,
cfg *config.Dendrite,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
) util.JSONResponse {
username := req.URL.Query().Get("username")
diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go
index 487081c5..aa5f13c4 100644
--- a/clientapi/routing/room_tagging.go
+++ b/clientapi/routing/room_tagging.go
@@ -40,7 +40,7 @@ func newTag() gomatrix.TagContent {
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
func GetTags(
req *http.Request,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
@@ -77,7 +77,7 @@ func GetTags(
// the tag to the "map" and saving the new "map" to the DB
func PutTag(
req *http.Request,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
@@ -134,7 +134,7 @@ func PutTag(
// the "map" and then saving the new "map" in the DB
func DeleteTag(
req *http.Request,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
device *authtypes.Device,
userID string,
roomID string,
@@ -203,7 +203,7 @@ func obtainSavedTags(
req *http.Request,
userID string,
roomID string,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
) (string, *gomatrixserverlib.ClientEvent, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
@@ -222,7 +222,7 @@ func saveTagData(
req *http.Request,
localpart string,
roomID string,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
Tag gomatrix.TagContent,
) error {
newTagData, err := json.Marshal(Tag)
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index f519523a..f0841b79 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -52,8 +52,8 @@ func Setup(
queryAPI roomserverAPI.RoomserverQueryAPI,
aliasAPI roomserverAPI.RoomserverAliasAPI,
asAPI appserviceAPI.AppServiceQueryAPI,
- accountDB *accounts.Database,
- deviceDB *devices.Database,
+ accountDB accounts.Database,
+ deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.KeyRing,
userUpdateProducer *producers.UserUpdateProducer,
diff --git a/clientapi/routing/sendtyping.go b/clientapi/routing/sendtyping.go
index 561a2d89..db3ab28b 100644
--- a/clientapi/routing/sendtyping.go
+++ b/clientapi/routing/sendtyping.go
@@ -34,7 +34,7 @@ type typingContentJSON struct {
// sends the typing events to client API typingProducer
func SendTyping(
req *http.Request, device *authtypes.Device, roomID string,
- userID string, accountDB *accounts.Database,
+ userID string, accountDB accounts.Database,
typingProducer *producers.TypingServerProducer,
) util.JSONResponse {
if device.UserID != userID {
diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go
index 88b02fe4..69383cdf 100644
--- a/clientapi/routing/threepid.go
+++ b/clientapi/routing/threepid.go
@@ -39,7 +39,7 @@ type threePIDsResponse struct {
// RequestEmailToken implements:
// POST /account/3pid/email/requestToken
// POST /register/email/requestToken
-func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg *config.Dendrite) util.JSONResponse {
+func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.Dendrite) util.JSONResponse {
var body threepid.EmailAssociationRequest
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr
@@ -82,7 +82,7 @@ func RequestEmailToken(req *http.Request, accountDB *accounts.Database, cfg *con
// CheckAndSave3PIDAssociation implements POST /account/3pid
func CheckAndSave3PIDAssociation(
- req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
+ req *http.Request, accountDB accounts.Database, device *authtypes.Device,
cfg *config.Dendrite,
) util.JSONResponse {
var body threepid.EmailAssociationCheckRequest
@@ -142,7 +142,7 @@ func CheckAndSave3PIDAssociation(
// GetAssociated3PIDs implements GET /account/3pid
func GetAssociated3PIDs(
- req *http.Request, accountDB *accounts.Database, device *authtypes.Device,
+ req *http.Request, accountDB accounts.Database, device *authtypes.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
@@ -161,7 +161,7 @@ func GetAssociated3PIDs(
}
// Forget3PID implements POST /account/3pid/delete
-func Forget3PID(req *http.Request, accountDB *accounts.Database) util.JSONResponse {
+func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
var body authtypes.ThreePID
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
return *reqErr
diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go
index 2cf88d6e..aa54aa9f 100644
--- a/clientapi/threepid/invites.go
+++ b/clientapi/threepid/invites.go
@@ -87,7 +87,7 @@ var (
func CheckAndProcessInvite(
ctx context.Context,
device *authtypes.Device, body *MembershipRequest, cfg *config.Dendrite,
- queryAPI api.RoomserverQueryAPI, db *accounts.Database,
+ queryAPI api.RoomserverQueryAPI, db accounts.Database,
producer *producers.RoomserverProducer, membership string, roomID string,
evTime time.Time,
) (inviteStoredOnIDServer bool, err error) {
@@ -137,7 +137,7 @@ func CheckAndProcessInvite(
// Returns an error if a check or a request failed.
func queryIDServer(
ctx context.Context,
- db *accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
+ db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string,
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
if err = isTrusted(body.IDServer, cfg); err != nil {
@@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
// Returns an error if the request failed to send or if the response couldn't be parsed.
func queryIDServerStoreInvite(
ctx context.Context,
- db *accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
+ db accounts.Database, cfg *config.Dendrite, device *authtypes.Device,
body *MembershipRequest, roomID string,
) (*idServerStoreInviteResponse, error) {
// Retrieve the sender's profile to get their display name
diff --git a/cmd/kafka-producer/main.go b/cmd/kafka-producer/main.go
index 8a4340f2..f5f243e4 100644
--- a/cmd/kafka-producer/main.go
+++ b/cmd/kafka-producer/main.go
@@ -21,7 +21,7 @@ import (
"os"
"strings"
- "github.com/Shopify/sarama"
+ sarama "gopkg.in/Shopify/sarama.v1"
)
const usage = `Usage: %s
diff --git a/common/basecomponent/base.go b/common/basecomponent/base.go
index dc917ffe..4274de2b 100644
--- a/common/basecomponent/base.go
+++ b/common/basecomponent/base.go
@@ -18,6 +18,7 @@ import (
"database/sql"
"io"
"net/http"
+ "net/url"
"golang.org/x/crypto/ed25519"
@@ -68,7 +69,13 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite {
logrus.WithError(err).Panicf("failed to start opentracing")
}
- kafkaConsumer, kafkaProducer := setupKafka(cfg)
+ var kafkaConsumer sarama.Consumer
+ var kafkaProducer sarama.SyncProducer
+ if cfg.Kafka.UseNaffka {
+ kafkaConsumer, kafkaProducer = setupNaffka(cfg)
+ } else {
+ kafkaConsumer, kafkaProducer = setupKafka(cfg)
+ }
return &BaseDendrite{
componentName: componentName,
@@ -118,7 +125,7 @@ func (b *BaseDendrite) CreateHTTPFederationSenderAPIs() federationSenderAPI.Fede
// CreateDeviceDB creates a new instance of the device database. Should only be
// called once per component.
-func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
+func (b *BaseDendrite) CreateDeviceDB() devices.Database {
db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to devices db")
@@ -129,7 +136,7 @@ func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
// CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component.
-func (b *BaseDendrite) CreateAccountsDB() *accounts.Database {
+func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db")
@@ -186,37 +193,58 @@ func (b *BaseDendrite) SetupAndServeHTTP(bindaddr string, listenaddr string) {
logrus.Infof("Stopped %s server on %s", b.componentName, addr)
}
-// setupKafka creates kafka consumer/producer pair from the config. Checks if
-// should use naffka.
+// setupKafka creates kafka consumer/producer pair from the config.
func setupKafka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
- if cfg.Kafka.UseNaffka {
- db, err := sql.Open("postgres", string(cfg.Database.Naffka))
+ consumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
+ if err != nil {
+ logrus.WithError(err).Panic("failed to start kafka consumer")
+ }
+
+ producer, err := sarama.NewSyncProducer(cfg.Kafka.Addresses, nil)
+ if err != nil {
+ logrus.WithError(err).Panic("failed to setup kafka producers")
+ }
+
+ return consumer, producer
+}
+
+// setupNaffka creates kafka consumer/producer pair from the config.
+func setupNaffka(cfg *config.Dendrite) (sarama.Consumer, sarama.SyncProducer) {
+ var err error
+ var db *sql.DB
+ var naffkaDB *naffka.DatabaseImpl
+
+ uri, err := url.Parse(string(cfg.Database.Naffka))
+ if err != nil || uri.Scheme == "file" {
+ db, err = sql.Open("sqlite3", string(cfg.Database.Naffka))
if err != nil {
logrus.WithError(err).Panic("Failed to open naffka database")
}
- naffkaDB, err := naffka.NewPostgresqlDatabase(db)
+ naffkaDB, err = naffka.NewSqliteDatabase(db)
if err != nil {
logrus.WithError(err).Panic("Failed to setup naffka database")
}
-
- naff, err := naffka.New(naffkaDB)
+ } else {
+ db, err = sql.Open("postgres", string(cfg.Database.Naffka))
if err != nil {
- logrus.WithError(err).Panic("Failed to setup naffka")
+ logrus.WithError(err).Panic("Failed to open naffka database")
}
- return naff, naff
+ naffkaDB, err = naffka.NewPostgresqlDatabase(db)
+ if err != nil {
+ logrus.WithError(err).Panic("Failed to setup naffka database")
+ }
}
- consumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
- if err != nil {
- logrus.WithError(err).Panic("failed to start kafka consumer")
+ if naffkaDB == nil {
+ panic("naffka connection string not understood")
}
- producer, err := sarama.NewSyncProducer(cfg.Kafka.Addresses, nil)
+ naff, err := naffka.New(naffkaDB)
if err != nil {
- logrus.WithError(err).Panic("failed to setup kafka producers")
+ logrus.WithError(err).Panic("Failed to setup naffka")
}
- return consumer, producer
+ return naff, naff
}
diff --git a/common/keydb/keydb.go b/common/keydb/keydb.go
index d1f2b7eb..cf15c9f0 100644
--- a/common/keydb/keydb.go
+++ b/common/keydb/keydb.go
@@ -21,6 +21,7 @@ import (
"golang.org/x/crypto/ed25519"
"github.com/matrix-org/dendrite/common/keydb/postgres"
+ "github.com/matrix-org/dendrite/common/keydb/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -44,6 +45,8 @@ func NewDatabase(
switch uri.Scheme {
case "postgres":
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
+ case "file":
+ return sqlite3.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
default:
return postgres.NewDatabase(dataSourceName, serverName, serverKey, serverKeyID)
}
diff --git a/common/keydb/sqlite3/keydb.go b/common/keydb/sqlite3/keydb.go
new file mode 100644
index 00000000..88eb9d9f
--- /dev/null
+++ b/common/keydb/sqlite3/keydb.go
@@ -0,0 +1,115 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "math"
+
+ "golang.org/x/crypto/ed25519"
+
+ "github.com/matrix-org/gomatrixserverlib"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+// A Database implements gomatrixserverlib.KeyDatabase and is used to store
+// the public keys for other matrix servers.
+type Database struct {
+ statements serverKeyStatements
+}
+
+// NewDatabase prepares a new key database.
+// It creates the necessary tables if they don't already exist.
+// It prepares all the SQL statements that it will use.
+// Returns an error if there was a problem talking to the database.
+func NewDatabase(
+ dataSourceName string,
+ serverName gomatrixserverlib.ServerName,
+ serverKey ed25519.PublicKey,
+ serverKeyID gomatrixserverlib.KeyID,
+) (*Database, error) {
+ db, err := sql.Open("sqlite3", dataSourceName)
+ if err != nil {
+ return nil, err
+ }
+ d := &Database{}
+ err = d.statements.prepare(db)
+ if err != nil {
+ return nil, err
+ }
+ // Store our own keys so that we don't end up making HTTP requests to find our
+ // own keys
+ index := gomatrixserverlib.PublicKeyLookupRequest{
+ ServerName: serverName,
+ KeyID: serverKeyID,
+ }
+ value := gomatrixserverlib.PublicKeyLookupResult{
+ VerifyKey: gomatrixserverlib.VerifyKey{
+ Key: gomatrixserverlib.Base64String(serverKey),
+ },
+ ValidUntilTS: math.MaxUint64 >> 1,
+ ExpiredTS: gomatrixserverlib.PublicKeyNotExpired,
+ }
+ err = d.StoreKeys(
+ context.Background(),
+ map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{
+ index: value,
+ },
+ )
+ if err != nil {
+ return nil, err
+ }
+ return d, nil
+}
+
+// FetcherName implements KeyFetcher
+func (d Database) FetcherName() string {
+ return "KeyDatabase"
+}
+
+// FetchKeys implements gomatrixserverlib.KeyDatabase
+func (d *Database) FetchKeys(
+ ctx context.Context,
+ requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
+) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
+ return d.statements.bulkSelectServerKeys(ctx, requests)
+}
+
+// StoreKeys implements gomatrixserverlib.KeyDatabase
+func (d *Database) StoreKeys(
+ ctx context.Context,
+ keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
+) error {
+ // TODO: Inserting all the keys within a single transaction may
+ // be more efficient since the transaction overhead can be quite
+ // high for a single insert statement.
+ var lastErr error
+ for request, keys := range keyMap {
+ if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil {
+ // Rather than returning immediately on error we try to insert the
+ // remaining keys.
+ // Since we are inserting the keys outside of a transaction it is
+ // possible for some of the inserts to succeed even though some
+ // of the inserts have failed.
+ // Ensuring that we always insert all the keys we can means that
+ // this behaviour won't depend on the iteration order of the map.
+ lastErr = err
+ }
+ }
+ return lastErr
+}
diff --git a/common/keydb/sqlite3/server_key_table.go b/common/keydb/sqlite3/server_key_table.go
new file mode 100644
index 00000000..6c33f30a
--- /dev/null
+++ b/common/keydb/sqlite3/server_key_table.go
@@ -0,0 +1,142 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const serverKeysSchema = `
+-- A cache of signing keys downloaded from remote servers.
+CREATE TABLE IF NOT EXISTS keydb_server_keys (
+ -- The name of the matrix server the key is for.
+ server_name TEXT NOT NULL,
+ -- The ID of the server key.
+ server_key_id TEXT NOT NULL,
+ -- Combined server name and key ID separated by the ASCII unit separator
+ -- to make it easier to run bulk queries.
+ server_name_and_key_id TEXT NOT NULL,
+ -- When the key is valid until as a millisecond timestamp.
+ -- 0 if this is an expired key (in which case expired_ts will be non-zero)
+ valid_until_ts BIGINT NOT NULL,
+ -- When the key expired as a millisecond timestamp.
+ -- 0 if this is an active key (in which case valid_until_ts will be non-zero)
+ expired_ts BIGINT NOT NULL,
+ -- The base64-encoded public key.
+ server_key TEXT NOT NULL,
+ UNIQUE (server_name, server_key_id)
+);
+
+CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
+`
+
+const bulkSelectServerKeysSQL = "" +
+ "SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
+ " server_key FROM keydb_server_keys" +
+ " WHERE server_name_and_key_id IN ($1)"
+
+const upsertServerKeysSQL = "" +
+ "INSERT INTO keydb_server_keys (server_name, server_key_id," +
+ " server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT (server_name, server_key_id)" +
+ " DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
+
+type serverKeyStatements struct {
+ bulkSelectServerKeysStmt *sql.Stmt
+ upsertServerKeysStmt *sql.Stmt
+}
+
+func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(serverKeysSchema)
+ if err != nil {
+ return
+ }
+ if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerKeysSQL); err != nil {
+ return
+ }
+ if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *serverKeyStatements) bulkSelectServerKeys(
+ ctx context.Context,
+ requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
+) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
+ var nameAndKeyIDs []string
+ for request := range requests {
+ nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
+ }
+ stmt := s.bulkSelectServerKeysStmt
+ rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
+ for rows.Next() {
+ var serverName string
+ var keyID string
+ var key string
+ var validUntilTS int64
+ var expiredTS int64
+ if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
+ return nil, err
+ }
+ r := gomatrixserverlib.PublicKeyLookupRequest{
+ ServerName: gomatrixserverlib.ServerName(serverName),
+ KeyID: gomatrixserverlib.KeyID(keyID),
+ }
+ vk := gomatrixserverlib.VerifyKey{}
+ err = vk.Key.Decode(key)
+ if err != nil {
+ return nil, err
+ }
+ results[r] = gomatrixserverlib.PublicKeyLookupResult{
+ VerifyKey: vk,
+ ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
+ ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
+ }
+ }
+ return results, nil
+}
+
+func (s *serverKeyStatements) upsertServerKeys(
+ ctx context.Context,
+ request gomatrixserverlib.PublicKeyLookupRequest,
+ key gomatrixserverlib.PublicKeyLookupResult,
+) error {
+ _, err := s.upsertServerKeysStmt.ExecContext(
+ ctx,
+ string(request.ServerName),
+ string(request.KeyID),
+ nameAndKeyID(request),
+ key.ValidUntilTS,
+ key.ExpiredTS,
+ key.Key.Encode(),
+ )
+ return err
+}
+
+func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
+ return string(request.ServerName) + "\x1F" + string(request.KeyID)
+}
diff --git a/common/partition_offset_table.go b/common/partition_offset_table.go
index 6955ac36..6bc066a6 100644
--- a/common/partition_offset_table.go
+++ b/common/partition_offset_table.go
@@ -29,7 +29,7 @@ CREATE TABLE IF NOT EXISTS ${prefix}_partition_offsets (
partition INTEGER NOT NULL,
-- The 64-bit offset.
partition_offset BIGINT NOT NULL,
- CONSTRAINT ${prefix}_topic_partition_unique UNIQUE (topic, partition)
+ UNIQUE (topic, partition)
);
`
@@ -38,7 +38,7 @@ const selectPartitionOffsetsSQL = "" +
const upsertPartitionOffsetsSQL = "" +
"INSERT INTO ${prefix}_partition_offsets (topic, partition, partition_offset) VALUES ($1, $2, $3)" +
- " ON CONFLICT ON CONSTRAINT ${prefix}_topic_partition_unique" +
+ " ON CONFLICT (topic, partition)" +
" DO UPDATE SET partition_offset = $3"
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
diff --git a/common/sql.go b/common/sql.go
index 7ac9ac14..97593020 100644
--- a/common/sql.go
+++ b/common/sql.go
@@ -16,6 +16,7 @@ package common
import (
"database/sql"
+ "fmt"
"github.com/lib/pq"
)
@@ -30,11 +31,13 @@ type Transaction interface {
// EndTransaction ends a transaction.
// If the transaction succeeded then it is committed, otherwise it is rolledback.
-func EndTransaction(txn Transaction, succeeded *bool) {
+// You MUST check the error returned from this function to be sure that the transaction
+// was applied correctly. For example, 'database is locked' errors in sqlite will happen here.
+func EndTransaction(txn Transaction, succeeded *bool) error {
if *succeeded {
- txn.Commit() // nolint: errcheck
+ return txn.Commit() // nolint: errcheck
} else {
- txn.Rollback() // nolint: errcheck
+ return txn.Rollback() // nolint: errcheck
}
}
@@ -47,7 +50,12 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
return
}
succeeded := false
- defer EndTransaction(txn, &succeeded)
+ defer func() {
+ err2 := EndTransaction(txn, &succeeded)
+ if err == nil && err2 != nil { // failed to commit/rollback
+ err = err2
+ }
+ }()
err = fn(txn)
if err != nil {
@@ -74,3 +82,20 @@ func IsUniqueConstraintViolationErr(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}
+
+// Hack of the century
+func QueryVariadic(count int) string {
+ return QueryVariadicOffset(count, 0)
+}
+
+func QueryVariadicOffset(count, offset int) string {
+ str := "("
+ for i := 0; i < count; i++ {
+ str += fmt.Sprintf("$%d", i+offset+1)
+ if i < (count - 1) {
+ str += ", "
+ }
+ }
+ str += ")"
+ return str
+}
diff --git a/docker/Dockerfile b/docker/Dockerfile
index c88b7761..29b27dde 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,9 +1,13 @@
+<<<<<<< HEAD
+FROM docker.io/golang:1.13.7-alpine3.11
+=======
FROM docker.io/golang:1.13.6-alpine
+>>>>>>> master
RUN mkdir /build
WORKDIR /build
-RUN apk --update --no-cache add openssl bash git
+RUN apk --update --no-cache add openssl bash git build-base
CMD ["bash", "docker/build.sh"]
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 9cf67457..d738ed3f 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -1,13 +1,21 @@
version: "3.4"
services:
+ riot:
+ image: vectorim/riot-web
+ networks:
+ - internal
+ ports:
+ - "8500:80"
+
monolith:
container_name: dendrite_monolith
hostname: monolith
- entrypoint: ["bash", "./docker/services/monolith.sh"]
+ entrypoint: ["bash", "./docker/services/monolith.sh", "--config", "/etc/dendrite/dendrite.yaml"]
build: ./
volumes:
- ..:/build
- ./build/bin:/build/bin
+ - ../cfg:/etc/dendrite
networks:
- internal
depends_on:
diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go
index 53851bc5..ef57da88 100644
--- a/federationapi/federationapi.go
+++ b/federationapi/federationapi.go
@@ -32,8 +32,8 @@ import (
// FederationAPI component.
func SetupFederationAPIComponent(
base *basecomponent.BaseDendrite,
- accountsDB *accounts.Database,
- deviceDB *devices.Database,
+ accountsDB accounts.Database,
+ deviceDB devices.Database,
federation *gomatrixserverlib.FederationClient,
keyRing *gomatrixserverlib.KeyRing,
aliasAPI roomserverAPI.RoomserverAliasAPI,
diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go
index ba8af7a9..78021c12 100644
--- a/federationapi/routing/devices.go
+++ b/federationapi/routing/devices.go
@@ -30,7 +30,7 @@ type userDevicesResponse struct {
// GetUserDevices for the given user id
func GetUserDevices(
req *http.Request,
- deviceDB *devices.Database,
+ deviceDB devices.Database,
userID string,
) util.JSONResponse {
localpart, err := userutil.ParseUsernameParam(userID, nil)
diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go
index 3be729c2..31b7a343 100644
--- a/federationapi/routing/profile.go
+++ b/federationapi/routing/profile.go
@@ -30,7 +30,7 @@ import (
// GetProfile implements GET /_matrix/federation/v1/query/profile
func GetProfile(
httpReq *http.Request,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
cfg *config.Dendrite,
asAPI appserviceAPI.AppServiceQueryAPI,
) util.JSONResponse {
diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go
index 13ed24f3..3b119301 100644
--- a/federationapi/routing/routing.go
+++ b/federationapi/routing/routing.go
@@ -51,8 +51,8 @@ func Setup(
federationSenderAPI federationSenderAPI.FederationSenderQueryAPI,
keys gomatrixserverlib.KeyRing,
federation *gomatrixserverlib.FederationClient,
- accountDB *accounts.Database,
- deviceDB *devices.Database,
+ accountDB accounts.Database,
+ deviceDB devices.Database,
) {
v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter()
v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter()
diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go
index 06e00eea..a22685f2 100644
--- a/federationapi/routing/threepid.go
+++ b/federationapi/routing/threepid.go
@@ -61,7 +61,7 @@ func CreateInvitesFrom3PIDInvites(
req *http.Request, queryAPI roomserverAPI.RoomserverQueryAPI,
asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite,
producer *producers.RoomserverProducer, federation *gomatrixserverlib.FederationClient,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
) util.JSONResponse {
var body invites
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
@@ -174,7 +174,7 @@ func createInviteFrom3PIDInvite(
ctx context.Context, queryAPI roomserverAPI.RoomserverQueryAPI,
asAPI appserviceAPI.AppServiceQueryAPI, cfg *config.Dendrite,
inv invite, federation *gomatrixserverlib.FederationClient,
- accountDB *accounts.Database,
+ accountDB accounts.Database,
) (*gomatrixserverlib.Event, error) {
_, server, err := gomatrixserverlib.SplitID('@', inv.MXID)
if err != nil {
diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go
index c60f6dc5..d97b5d29 100644
--- a/federationsender/storage/postgres/storage.go
+++ b/federationsender/storage/postgres/storage.go
@@ -87,7 +87,7 @@ func (d *Database) UpdateRoom(
return nil
}
- if lastSentEventID != oldEventID {
+ if lastSentEventID != "" && lastSentEventID != oldEventID {
return types.EventIDMismatchError{
DatabaseID: lastSentEventID, RoomServerID: oldEventID,
}
diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go
new file mode 100644
index 00000000..1437a062
--- /dev/null
+++ b/federationsender/storage/sqlite3/joined_hosts_table.go
@@ -0,0 +1,139 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/federationsender/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const joinedHostsSchema = `
+-- The joined_hosts table stores a list of m.room.member event ids in the
+-- current state for each room where the membership is "join".
+-- There will be an entry for every user that is joined to the room.
+CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
+ -- The string ID of the room.
+ room_id TEXT NOT NULL,
+ -- The event ID of the m.room.member join event.
+ event_id TEXT NOT NULL,
+ -- The domain part of the user ID the m.room.member event is for.
+ server_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
+ ON federationsender_joined_hosts (event_id);
+
+CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
+ ON federationsender_joined_hosts (room_id)
+`
+
+const insertJoinedHostsSQL = "" +
+ "INSERT INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
+ " VALUES ($1, $2, $3)"
+
+const deleteJoinedHostsSQL = "" +
+ "DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
+
+const selectJoinedHostsSQL = "" +
+ "SELECT event_id, server_name FROM federationsender_joined_hosts" +
+ " WHERE room_id = $1"
+
+type joinedHostsStatements struct {
+ insertJoinedHostsStmt *sql.Stmt
+ deleteJoinedHostsStmt *sql.Stmt
+ selectJoinedHostsStmt *sql.Stmt
+}
+
+func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(joinedHostsSchema)
+ if err != nil {
+ return
+ }
+ if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil {
+ return
+ }
+ if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
+ return
+ }
+ if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *joinedHostsStatements) insertJoinedHosts(
+ ctx context.Context,
+ txn *sql.Tx,
+ roomID, eventID string,
+ serverName gomatrixserverlib.ServerName,
+) error {
+ stmt := common.TxStmt(txn, s.insertJoinedHostsStmt)
+ _, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
+ return err
+}
+
+func (s *joinedHostsStatements) deleteJoinedHosts(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) error {
+ for _, eventID := range eventIDs {
+ stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt)
+ if _, err := stmt.ExecContext(ctx, eventID); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *joinedHostsStatements) selectJoinedHostsWithTx(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) ([]types.JoinedHost, error) {
+ stmt := common.TxStmt(txn, s.selectJoinedHostsStmt)
+ return joinedHostsFromStmt(ctx, stmt, roomID)
+}
+
+func (s *joinedHostsStatements) selectJoinedHosts(
+ ctx context.Context, roomID string,
+) ([]types.JoinedHost, error) {
+ return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
+}
+
+func joinedHostsFromStmt(
+ ctx context.Context, stmt *sql.Stmt, roomID string,
+) ([]types.JoinedHost, error) {
+ rows, err := stmt.QueryContext(ctx, roomID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ var result []types.JoinedHost
+ for rows.Next() {
+ var eventID, serverName string
+ if err = rows.Scan(&eventID, &serverName); err != nil {
+ return nil, err
+ }
+ result = append(result, types.JoinedHost{
+ MemberEventID: eventID,
+ ServerName: gomatrixserverlib.ServerName(serverName),
+ })
+ }
+
+ return result, nil
+}
diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go
new file mode 100644
index 00000000..6361400d
--- /dev/null
+++ b/federationsender/storage/sqlite3/room_table.go
@@ -0,0 +1,101 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+)
+
+const roomSchema = `
+CREATE TABLE IF NOT EXISTS federationsender_rooms (
+ -- The string ID of the room
+ room_id TEXT PRIMARY KEY,
+ -- The most recent event state by the room server.
+ -- We can use this to tell if our view of the room state has become
+ -- desynchronised.
+ last_event_id TEXT NOT NULL
+);`
+
+const insertRoomSQL = "" +
+ "INSERT INTO federationsender_rooms (room_id, last_event_id) VALUES ($1, '')" +
+ " ON CONFLICT DO NOTHING"
+
+const selectRoomForUpdateSQL = "" +
+ "SELECT last_event_id FROM federationsender_rooms WHERE room_id = $1"
+
+const updateRoomSQL = "" +
+ "UPDATE federationsender_rooms SET last_event_id = $2 WHERE room_id = $1"
+
+type roomStatements struct {
+ insertRoomStmt *sql.Stmt
+ selectRoomForUpdateStmt *sql.Stmt
+ updateRoomStmt *sql.Stmt
+}
+
+func (s *roomStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(roomSchema)
+ if err != nil {
+ return
+ }
+
+ if s.insertRoomStmt, err = db.Prepare(insertRoomSQL); err != nil {
+ return
+ }
+ if s.selectRoomForUpdateStmt, err = db.Prepare(selectRoomForUpdateSQL); err != nil {
+ return
+ }
+ if s.updateRoomStmt, err = db.Prepare(updateRoomSQL); err != nil {
+ return
+ }
+ return
+}
+
+// insertRoom inserts the room if it didn't already exist.
+// If the room didn't exist then last_event_id is set to the empty string.
+func (s *roomStatements) insertRoom(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) error {
+ _, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
+ return err
+}
+
+// selectRoomForUpdate locks the row for the room and returns the last_event_id.
+// The row must already exist in the table. Callers can ensure that the row
+// exists by calling insertRoom first.
+func (s *roomStatements) selectRoomForUpdate(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (string, error) {
+ var lastEventID string
+ stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt)
+ err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
+ if err != nil {
+ return "", err
+ }
+ return lastEventID, nil
+}
+
+// updateRoom updates the last_event_id for the room. selectRoomForUpdate should
+// have already been called earlier within the transaction.
+func (s *roomStatements) updateRoom(
+ ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
+) error {
+ stmt := common.TxStmt(txn, s.updateRoomStmt)
+ _, err := stmt.ExecContext(ctx, roomID, lastEventID)
+ return err
+}
diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go
new file mode 100644
index 00000000..f9cfaa99
--- /dev/null
+++ b/federationsender/storage/sqlite3/storage.go
@@ -0,0 +1,124 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ _ "github.com/mattn/go-sqlite3"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/federationsender/types"
+)
+
+// Database stores information needed by the federation sender
+type Database struct {
+ joinedHostsStatements
+ roomStatements
+ common.PartitionOffsetStatements
+ db *sql.DB
+}
+
+// NewDatabase opens a new database
+func NewDatabase(dataSourceName string) (*Database, error) {
+ var result Database
+ var err error
+ if result.db, err = sql.Open("sqlite3", dataSourceName); err != nil {
+ return nil, err
+ }
+ if err = result.prepare(); err != nil {
+ return nil, err
+ }
+ return &result, nil
+}
+
+func (d *Database) prepare() error {
+ var err error
+
+ if err = d.joinedHostsStatements.prepare(d.db); err != nil {
+ return err
+ }
+
+ if err = d.roomStatements.prepare(d.db); err != nil {
+ return err
+ }
+
+ return d.PartitionOffsetStatements.Prepare(d.db, "federationsender")
+}
+
+// UpdateRoom updates the joined hosts for a room and returns what the joined
+// hosts were before the update, or nil if this was a duplicate message.
+// This is called when we receive a message from kafka, so we pass in
+// oldEventID and newEventID to check that we haven't missed any messages or
+// this isn't a duplicate message.
+func (d *Database) UpdateRoom(
+ ctx context.Context,
+ roomID, oldEventID, newEventID string,
+ addHosts []types.JoinedHost,
+ removeHosts []string,
+) (joinedHosts []types.JoinedHost, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ err = d.insertRoom(ctx, txn, roomID)
+ if err != nil {
+ return err
+ }
+
+ lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID)
+ if err != nil {
+ return err
+ }
+
+ if lastSentEventID == newEventID {
+ // We've handled this message before, so let's just ignore it.
+ // We can only get a duplicate for the last message we processed,
+ // so its enough just to compare the newEventID with lastSentEventID
+ return nil
+ }
+
+ if lastSentEventID != "" && lastSentEventID != oldEventID {
+ return types.EventIDMismatchError{
+ DatabaseID: lastSentEventID, RoomServerID: oldEventID,
+ }
+ }
+
+ joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID)
+ if err != nil {
+ return err
+ }
+
+ for _, add := range addHosts {
+ err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
+ if err != nil {
+ return err
+ }
+ }
+ if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil {
+ return err
+ }
+ return d.updateRoom(ctx, txn, roomID, newEventID)
+ })
+ return
+}
+
+// GetJoinedHosts returns the currently joined hosts for room,
+// as known to federationserver.
+// Returns an error if something goes wrong.
+func (d *Database) GetJoinedHosts(
+ ctx context.Context, roomID string,
+) ([]types.JoinedHost, error) {
+ return d.selectJoinedHosts(ctx, roomID)
+}
diff --git a/federationsender/storage/storage.go b/federationsender/storage/storage.go
index 4ce151c7..e83c1e9d 100644
--- a/federationsender/storage/storage.go
+++ b/federationsender/storage/storage.go
@@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/federationsender/storage/postgres"
+ "github.com/matrix-org/dendrite/federationsender/storage/sqlite3"
"github.com/matrix-org/dendrite/federationsender/types"
)
@@ -36,6 +37,8 @@ func NewDatabase(dataSourceName string) (Database, error) {
return postgres.NewDatabase(dataSourceName)
}
switch uri.Scheme {
+ case "file":
+ return sqlite3.NewDatabase(dataSourceName)
case "postgres":
return postgres.NewDatabase(dataSourceName)
default:
diff --git a/go.mod b/go.mod
index 990b839e..2d442cd0 100644
--- a/go.mod
+++ b/go.mod
@@ -1,30 +1,30 @@
module github.com/matrix-org/dendrite
require (
- github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3
+ github.com/DataDog/zstd v1.4.4 // indirect
github.com/Shopify/toxiproxy v2.1.4+incompatible // indirect
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
- github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42 // indirect
- github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934 // indirect
+ github.com/eapache/go-resiliency v1.2.0 // indirect
+ github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 // indirect
github.com/eapache/queue v1.1.0 // indirect
- github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
+ github.com/frankban/quicktest v1.7.2 // indirect
+ github.com/golang/snappy v0.0.1 // indirect
github.com/gorilla/mux v1.7.3
- github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 // indirect
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
github.com/lib/pq v1.2.0
github.com/matrix-org/dugong v0.0.0-20171220115018-ea0a4690a0d5
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5
- github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0
+ github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5
+ github.com/mattn/go-sqlite3 v2.0.2+incompatible
github.com/miekg/dns v1.1.12 // indirect
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5
github.com/opentracing/opentracing-go v1.0.2
- github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac // indirect
- github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897 // indirect
+ github.com/pierrec/lz4 v2.4.1+incompatible // indirect
github.com/pkg/errors v0.8.1
github.com/prometheus/client_golang v1.2.1
- github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 // indirect
+ github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 // indirect
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0 // indirect
github.com/uber-go/atomic v1.3.0 // indirect
@@ -34,7 +34,7 @@ require (
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550
golang.org/x/net v0.0.0-20190620200207-3b0461eec859 // indirect
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6 // indirect
- gopkg.in/Shopify/sarama.v1 v1.11.0
+ gopkg.in/Shopify/sarama.v1 v1.20.1
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/h2non/bimg.v1 v1.0.18
gopkg.in/yaml.v2 v2.2.2
diff --git a/go.sum b/go.sum
index 42a145d6..7c8732f6 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,5 @@
-github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3 h1:j6BAEHYn1kUyW2j7kY0mOJ/R8A0qWwXpvUAEHGemm/g=
-github.com/Shopify/sarama v0.0.0-20170127151855-574d3147eee3/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
+github.com/DataDog/zstd v1.4.4 h1:+IawcoXhCBylN7ccwdwf8LOH2jKq7NavGpEPanrlTzE=
+github.com/DataDog/zstd v1.4.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
@@ -18,13 +18,15 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42 h1:f8ERmXYuaC+kCSv2w+y3rBK/oVu6If4DEm3jywJJ0hc=
-github.com/eapache/go-resiliency v0.0.0-20160104191539-b86b1ec0dd42/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
-github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934 h1:oGLoaVIefp3tiOgi7+KInR/nNPvEpPM6GFo+El7fd14=
-github.com/eapache/go-xerial-snappy v0.0.0-20160609142408-bb955e01b934/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
+github.com/eapache/go-resiliency v1.2.0 h1:v7g92e/KSN71Rq7vSThKaWIq68fL4YHvWyiUKorFR1Q=
+github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
+github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw=
+github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k=
+github.com/frankban/quicktest v1.7.2 h1:2QxQoC1TS09S7fhCPsrvqYdvP1H5M1P1ih5ABm3BTYk=
+github.com/frankban/quicktest v1.7.2/go.mod h1:jaStnuzAqU1AJdCO0l53JDCJrVDKcS03DbaAcR7Ks/o=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
@@ -35,11 +37,13 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
-github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
+github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
+github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
+github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg=
+github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
@@ -48,8 +52,6 @@ github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplb
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
-github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6 h1:KAZ1BW2TCmT6PRihDPpocIy1QTtsAsrx6TneU/4+CMg=
-github.com/klauspost/crc32 v0.0.0-20161016154125-cb6bfca970f6/go.mod h1:+ZoRqAPRLkC4NPOvfYeR5KNOrY6TD+/sAC3HXPZgDYg=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
@@ -69,10 +71,12 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5 h1:kmRjpmFOenVpOaV/DRlo9p6z/IbOKlUC+hhKsAAh8Qg=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200124100636-0c2ec91d1df5/go.mod h1:FsKa2pWE/bpQql9H7U4boOPXFoJX/QcqaZZ6ijLkaZI=
-github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0 h1:p7WTwG+aXM86+yVrYAiCMW3ZHSmotVvuRbjtt3jC+4A=
-github.com/matrix-org/naffka v0.0.0-20171115094957-662bfd0841d0/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
+github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1 h1:osLoFdOy+ChQqVUn2PeTDETFftVkl4w9t/OW18g3lnk=
+github.com/matrix-org/naffka v0.0.0-20200127221512-0716baaabaf1/go.mod h1:cXoYQIENbdWIQHt1SyCo6Bl3C3raHwJ0wgVrXHSqf+A=
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5 h1:W7l5CP4V7wPyPb4tYE11dbmeAOwtFQBTW0rf4OonOS8=
github.com/matrix-org/util v0.0.0-20171127121716-2e2df66af2f5/go.mod h1:lePuOiXLNDott7NZfnQvJk0lAZ5HgvIuWGhel6J+RLA=
+github.com/mattn/go-sqlite3 v2.0.2+incompatible h1:qzw9c2GNT8UFrgWNDhCTqRqYUSmu/Dav/9Z58LGpk7U=
+github.com/mattn/go-sqlite3 v2.0.2+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.1.4 h1:rCMZsU2ScVSYcAsOXgmC6+AKOK+6pmQTOcw03nfwYV0=
@@ -90,10 +94,8 @@ github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/R
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg=
github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o=
-github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac h1:tKcxwAA5OHUQjL6sWsuCIcP9OnzN+RwKfvomtIOsfy8=
-github.com/pierrec/lz4 v0.0.0-20161206202305-5c9560bfa9ac/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
-github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897 h1:jp3jc/PyyTrTKjJJ6rWnhTbmo7tGgBFyG9AL5FIrO1I=
-github.com/pierrec/xxHash v0.0.0-20160112165351-5a004441f897/go.mod h1:w2waW5Zoa/Wc4Yqe0wgrIYAGKqRMf7czn2HNKXmuL+I=
+github.com/pierrec/lz4 v2.4.1+incompatible h1:mFe7ttWaflA46Mhqh+jUfjp2qTbPYxLB2/OyBppH9dg=
+github.com/pierrec/lz4 v2.4.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -114,8 +116,8 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R
github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/prometheus/procfs v0.0.5 h1:3+auTFlqw+ZaQYJARz6ArODtkaIwtvBTx3N2NehQlL8=
github.com/prometheus/procfs v0.0.5/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ=
-github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5 h1:gwcdIpH6NU2iF8CmcqD+CP6+1CkRBOhHaPR+iu6raBY=
-github.com/rcrowley/go-metrics v0.0.0-20161128210544-1f30fe9094a5/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
+github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563 h1:dY6ETXrvDG7Sa4vE8ZQG4yqWg6UnOcbqTAahkV813vQ=
+github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME=
github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
@@ -172,8 +174,8 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY=
golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-gopkg.in/Shopify/sarama.v1 v1.11.0 h1:/3kaCyeYaPbr59IBjeqhIcUOB1vXlIVqXAYa5g5C5F0=
-gopkg.in/Shopify/sarama.v1 v1.11.0/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc=
+gopkg.in/Shopify/sarama.v1 v1.20.1 h1:Gi09A3fJXm0Jgt8kuKZ8YK+r60GfYn7MQuEmI3oq6hE=
+gopkg.in/Shopify/sarama.v1 v1.20.1/go.mod h1:AxnvoaevB2nBjNK17cG61A3LleFcWFwVBHBt+cot4Oc=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go
index 46d1c328..f2e614c1 100644
--- a/mediaapi/mediaapi.go
+++ b/mediaapi/mediaapi.go
@@ -27,7 +27,7 @@ import (
// component.
func SetupMediaAPIComponent(
base *basecomponent.BaseDendrite,
- deviceDB *devices.Database,
+ deviceDB devices.Database,
) {
mediaDB, err := storage.Open(string(base.Cfg.Database.MediaAPI))
if err != nil {
diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go
index dcc6ac06..71dad19b 100644
--- a/mediaapi/routing/routing.go
+++ b/mediaapi/routing/routing.go
@@ -44,7 +44,7 @@ func Setup(
apiMux *mux.Router,
cfg *config.Dendrite,
db storage.Database,
- deviceDB *devices.Database,
+ deviceDB devices.Database,
client *gomatrixserverlib.Client,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
diff --git a/publicroomsapi/publicroomsapi.go b/publicroomsapi/publicroomsapi.go
index 181966d3..1e2a3f9b 100644
--- a/publicroomsapi/publicroomsapi.go
+++ b/publicroomsapi/publicroomsapi.go
@@ -28,7 +28,7 @@ import (
// component.
func SetupPublicRoomsAPIComponent(
base *basecomponent.BaseDendrite,
- deviceDB *devices.Database,
+ deviceDB devices.Database,
rsQueryAPI roomserverAPI.RoomserverQueryAPI,
) {
publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI))
diff --git a/publicroomsapi/routing/routing.go b/publicroomsapi/routing/routing.go
index 3d2d2ac0..1953e04f 100644
--- a/publicroomsapi/routing/routing.go
+++ b/publicroomsapi/routing/routing.go
@@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0"
// Due to Setup being used to call many other functions, a gocyclo nolint is
// applied:
// nolint: gocyclo
-func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB storage.Database) {
+func Setup(apiMux *mux.Router, deviceDB devices.Database, publicRoomsDB storage.Database) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
authData := auth.Data{
diff --git a/publicroomsapi/storage/sqlite3/prepare.go b/publicroomsapi/storage/sqlite3/prepare.go
new file mode 100644
index 00000000..482dfa2b
--- /dev/null
+++ b/publicroomsapi/storage/sqlite3/prepare.go
@@ -0,0 +1,36 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "database/sql"
+)
+
+// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
+type statementList []struct {
+ statement **sql.Stmt
+ sql string
+}
+
+// prepare the SQL for each statement in the list and assign the result to the prepared statement.
+func (s statementList) prepare(db *sql.DB) (err error) {
+ for _, statement := range s {
+ if *statement.statement, err = db.Prepare(statement.sql); err != nil {
+ return
+ }
+ }
+ return
+}
diff --git a/publicroomsapi/storage/sqlite3/public_rooms_table.go b/publicroomsapi/storage/sqlite3/public_rooms_table.go
new file mode 100644
index 00000000..06c74a33
--- /dev/null
+++ b/publicroomsapi/storage/sqlite3/public_rooms_table.go
@@ -0,0 +1,277 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/publicroomsapi/types"
+)
+
+var editableAttributes = []string{
+ "aliases",
+ "canonical_alias",
+ "name",
+ "topic",
+ "world_readable",
+ "guest_can_join",
+ "avatar_url",
+ "visibility",
+}
+
+const publicRoomsSchema = `
+-- Stores all of the rooms with data needed to create the server's room directory
+CREATE TABLE IF NOT EXISTS publicroomsapi_public_rooms(
+ -- The room's ID
+ room_id TEXT NOT NULL PRIMARY KEY,
+ -- Number of joined members in the room
+ joined_members INTEGER NOT NULL DEFAULT 0,
+ -- Aliases of the room (empty array if none)
+ aliases TEXT[] NOT NULL DEFAULT '{}'::TEXT[],
+ -- Canonical alias of the room (empty string if none)
+ canonical_alias TEXT NOT NULL DEFAULT '',
+ -- Name of the room (empty string if none)
+ name TEXT NOT NULL DEFAULT '',
+ -- Topic of the room (empty string if none)
+ topic TEXT NOT NULL DEFAULT '',
+ -- Is the room world readable?
+ world_readable BOOLEAN NOT NULL DEFAULT false,
+ -- Can guest join the room?
+ guest_can_join BOOLEAN NOT NULL DEFAULT false,
+ -- URL of the room avatar (empty string if none)
+ avatar_url TEXT NOT NULL DEFAULT '',
+ -- Visibility of the room: true means the room is publicly visible, false
+ -- means the room is private
+ visibility BOOLEAN NOT NULL DEFAULT false
+);
+`
+
+const countPublicRoomsSQL = "" +
+ "SELECT COUNT(*) FROM publicroomsapi_public_rooms" +
+ " WHERE visibility = true"
+
+const selectPublicRoomsSQL = "" +
+ "SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
+ " FROM publicroomsapi_public_rooms WHERE visibility = true" +
+ " ORDER BY joined_members DESC" +
+ " OFFSET $1"
+
+const selectPublicRoomsWithLimitSQL = "" +
+ "SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
+ " FROM publicroomsapi_public_rooms WHERE visibility = true" +
+ " ORDER BY joined_members DESC" +
+ " OFFSET $1 LIMIT $2"
+
+const selectPublicRoomsWithFilterSQL = "" +
+ "SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
+ " FROM publicroomsapi_public_rooms" +
+ " WHERE visibility = true" +
+ " AND (LOWER(name) LIKE LOWER($1)" +
+ " OR LOWER(topic) LIKE LOWER($1)" +
+ " OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" +
+ " ORDER BY joined_members DESC" +
+ " OFFSET $2"
+
+const selectPublicRoomsWithLimitAndFilterSQL = "" +
+ "SELECT room_id, joined_members, aliases, canonical_alias, name, topic, world_readable, guest_can_join, avatar_url" +
+ " FROM publicroomsapi_public_rooms" +
+ " WHERE visibility = true" +
+ " AND (LOWER(name) LIKE LOWER($1)" +
+ " OR LOWER(topic) LIKE LOWER($1)" +
+ " OR LOWER(ARRAY_TO_STRING(aliases, ',')) LIKE LOWER($1))" +
+ " ORDER BY joined_members DESC" +
+ " OFFSET $2 LIMIT $3"
+
+const selectRoomVisibilitySQL = "" +
+ "SELECT visibility FROM publicroomsapi_public_rooms" +
+ " WHERE room_id = $1"
+
+const insertNewRoomSQL = "" +
+ "INSERT INTO publicroomsapi_public_rooms(room_id)" +
+ " VALUES ($1)"
+
+const incrementJoinedMembersInRoomSQL = "" +
+ "UPDATE publicroomsapi_public_rooms" +
+ " SET joined_members = joined_members + 1" +
+ " WHERE room_id = $1"
+
+const decrementJoinedMembersInRoomSQL = "" +
+ "UPDATE publicroomsapi_public_rooms" +
+ " SET joined_members = joined_members - 1" +
+ " WHERE room_id = $1"
+
+const updateRoomAttributeSQL = "" +
+ "UPDATE publicroomsapi_public_rooms" +
+ " SET %s = $1" +
+ " WHERE room_id = $2"
+
+type publicRoomsStatements struct {
+ countPublicRoomsStmt *sql.Stmt
+ selectPublicRoomsStmt *sql.Stmt
+ selectPublicRoomsWithLimitStmt *sql.Stmt
+ selectPublicRoomsWithFilterStmt *sql.Stmt
+ selectPublicRoomsWithLimitAndFilterStmt *sql.Stmt
+ selectRoomVisibilityStmt *sql.Stmt
+ insertNewRoomStmt *sql.Stmt
+ incrementJoinedMembersInRoomStmt *sql.Stmt
+ decrementJoinedMembersInRoomStmt *sql.Stmt
+ updateRoomAttributeStmts map[string]*sql.Stmt
+}
+
+func (s *publicRoomsStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(publicRoomsSchema)
+ if err != nil {
+ return
+ }
+
+ stmts := statementList{
+ {&s.countPublicRoomsStmt, countPublicRoomsSQL},
+ {&s.selectPublicRoomsStmt, selectPublicRoomsSQL},
+ {&s.selectPublicRoomsWithLimitStmt, selectPublicRoomsWithLimitSQL},
+ {&s.selectPublicRoomsWithFilterStmt, selectPublicRoomsWithFilterSQL},
+ {&s.selectPublicRoomsWithLimitAndFilterStmt, selectPublicRoomsWithLimitAndFilterSQL},
+ {&s.selectRoomVisibilityStmt, selectRoomVisibilitySQL},
+ {&s.insertNewRoomStmt, insertNewRoomSQL},
+ {&s.incrementJoinedMembersInRoomStmt, incrementJoinedMembersInRoomSQL},
+ {&s.decrementJoinedMembersInRoomStmt, decrementJoinedMembersInRoomSQL},
+ }
+
+ if err = stmts.prepare(db); err != nil {
+ return
+ }
+
+ s.updateRoomAttributeStmts = make(map[string]*sql.Stmt)
+ for _, editable := range editableAttributes {
+ stmt := fmt.Sprintf(updateRoomAttributeSQL, editable)
+ if s.updateRoomAttributeStmts[editable], err = db.Prepare(stmt); err != nil {
+ return
+ }
+ }
+
+ return
+}
+
+func (s *publicRoomsStatements) countPublicRooms(ctx context.Context) (nb int64, err error) {
+ err = s.countPublicRoomsStmt.QueryRowContext(ctx).Scan(&nb)
+ return
+}
+
+func (s *publicRoomsStatements) selectPublicRooms(
+ ctx context.Context, offset int64, limit int16, filter string,
+) ([]types.PublicRoom, error) {
+ var rows *sql.Rows
+ var err error
+
+ if len(filter) > 0 {
+ pattern := "%" + filter + "%"
+ if limit == 0 {
+ rows, err = s.selectPublicRoomsWithFilterStmt.QueryContext(
+ ctx, pattern, offset,
+ )
+ } else {
+ rows, err = s.selectPublicRoomsWithLimitAndFilterStmt.QueryContext(
+ ctx, pattern, offset, limit,
+ )
+ }
+ } else {
+ if limit == 0 {
+ rows, err = s.selectPublicRoomsStmt.QueryContext(ctx, offset)
+ } else {
+ rows, err = s.selectPublicRoomsWithLimitStmt.QueryContext(
+ ctx, offset, limit,
+ )
+ }
+ }
+
+ if err != nil {
+ return []types.PublicRoom{}, nil
+ }
+
+ rooms := []types.PublicRoom{}
+ for rows.Next() {
+ var r types.PublicRoom
+ var aliases pq.StringArray
+
+ err = rows.Scan(
+ &r.RoomID, &r.NumJoinedMembers, &aliases, &r.CanonicalAlias,
+ &r.Name, &r.Topic, &r.WorldReadable, &r.GuestCanJoin, &r.AvatarURL,
+ )
+ if err != nil {
+ return rooms, err
+ }
+
+ r.Aliases = aliases
+
+ rooms = append(rooms, r)
+ }
+
+ return rooms, nil
+}
+
+func (s *publicRoomsStatements) selectRoomVisibility(
+ ctx context.Context, roomID string,
+) (v bool, err error) {
+ err = s.selectRoomVisibilityStmt.QueryRowContext(ctx, roomID).Scan(&v)
+ return
+}
+
+func (s *publicRoomsStatements) insertNewRoom(
+ ctx context.Context, roomID string,
+) error {
+ _, err := s.insertNewRoomStmt.ExecContext(ctx, roomID)
+ return err
+}
+
+func (s *publicRoomsStatements) incrementJoinedMembersInRoom(
+ ctx context.Context, roomID string,
+) error {
+ _, err := s.incrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
+ return err
+}
+
+func (s *publicRoomsStatements) decrementJoinedMembersInRoom(
+ ctx context.Context, roomID string,
+) error {
+ _, err := s.decrementJoinedMembersInRoomStmt.ExecContext(ctx, roomID)
+ return err
+}
+
+func (s *publicRoomsStatements) updateRoomAttribute(
+ ctx context.Context, attrName string, attrValue attributeValue, roomID string,
+) error {
+ stmt, isEditable := s.updateRoomAttributeStmts[attrName]
+
+ if !isEditable {
+ return errors.New("Cannot edit " + attrName)
+ }
+
+ var value interface{}
+ switch v := attrValue.(type) {
+ case []string:
+ value = pq.StringArray(v)
+ case bool, string:
+ value = attrValue
+ default:
+ return errors.New("Unsupported attribute type, must be bool, string or []string")
+ }
+
+ _, err := stmt.ExecContext(ctx, value, roomID)
+ return err
+}
diff --git a/publicroomsapi/storage/sqlite3/storage.go b/publicroomsapi/storage/sqlite3/storage.go
new file mode 100644
index 00000000..dcb8920f
--- /dev/null
+++ b/publicroomsapi/storage/sqlite3/storage.go
@@ -0,0 +1,256 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ _ "github.com/mattn/go-sqlite3"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/publicroomsapi/types"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// PublicRoomsServerDatabase represents a public rooms server database.
+type PublicRoomsServerDatabase struct {
+ db *sql.DB
+ common.PartitionOffsetStatements
+ statements publicRoomsStatements
+}
+
+type attributeValue interface{}
+
+// NewPublicRoomsServerDatabase creates a new public rooms server database.
+func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) {
+ var db *sql.DB
+ var err error
+ if db, err = sql.Open("sqlite3", dataSourceName); err != nil {
+ return nil, err
+ }
+ storage := PublicRoomsServerDatabase{
+ db: db,
+ }
+ if err = storage.PartitionOffsetStatements.Prepare(db, "publicroomsapi"); err != nil {
+ return nil, err
+ }
+ if err = storage.statements.prepare(db); err != nil {
+ return nil, err
+ }
+ return &storage, nil
+}
+
+// GetRoomVisibility returns the room visibility as a boolean: true if the room
+// is publicly visible, false if not.
+// Returns an error if the retrieval failed.
+func (d *PublicRoomsServerDatabase) GetRoomVisibility(
+ ctx context.Context, roomID string,
+) (bool, error) {
+ return d.statements.selectRoomVisibility(ctx, roomID)
+}
+
+// SetRoomVisibility updates the visibility attribute of a room. This attribute
+// must be set to true if the room is publicly visible, false if not.
+// Returns an error if the update failed.
+func (d *PublicRoomsServerDatabase) SetRoomVisibility(
+ ctx context.Context, visible bool, roomID string,
+) error {
+ return d.statements.updateRoomAttribute(ctx, "visibility", visible, roomID)
+}
+
+// CountPublicRooms returns the number of room set as publicly visible on the server.
+// Returns an error if the retrieval failed.
+func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) {
+ return d.statements.countPublicRooms(ctx)
+}
+
+// GetPublicRooms returns an array containing the local rooms set as publicly visible, ordered by their number
+// of joined members. This array can be limited by a given number of elements, and offset by a given value.
+// If the limit is 0, doesn't limit the number of results. If the offset is 0 too, the array contains all
+// the rooms set as publicly visible on the server.
+// Returns an error if the retrieval failed.
+func (d *PublicRoomsServerDatabase) GetPublicRooms(
+ ctx context.Context, offset int64, limit int16, filter string,
+) ([]types.PublicRoom, error) {
+ return d.statements.selectPublicRooms(ctx, offset, limit, filter)
+}
+
+// UpdateRoomFromEvents iterate over a slice of state events and call
+// UpdateRoomFromEvent on each of them to update the database representation of
+// the rooms updated by each event.
+// The slice of events to remove is used to update the number of joined members
+// for the room in the database.
+// If the update triggered by one of the events failed, aborts the process and
+// returns an error.
+func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents(
+ ctx context.Context,
+ eventsToAdd []gomatrixserverlib.Event,
+ eventsToRemove []gomatrixserverlib.Event,
+) error {
+ for _, event := range eventsToAdd {
+ if err := d.UpdateRoomFromEvent(ctx, event); err != nil {
+ return err
+ }
+ }
+
+ for _, event := range eventsToRemove {
+ if event.Type() == "m.room.member" {
+ if err := d.updateNumJoinedUsers(ctx, event, true); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// UpdateRoomFromEvent updates the database representation of a room from a Matrix event, by
+// checking the event's type to know which attribute to change and using the event's content
+// to define the new value of the attribute.
+// If the event doesn't match with any property used to compute the public room directory,
+// does nothing.
+// If something went wrong during the process, returns an error.
+func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent(
+ ctx context.Context, event gomatrixserverlib.Event,
+) error {
+ // Process the event according to its type
+ switch event.Type() {
+ case "m.room.create":
+ return d.statements.insertNewRoom(ctx, event.RoomID())
+ case "m.room.member":
+ return d.updateNumJoinedUsers(ctx, event, false)
+ case "m.room.aliases":
+ return d.updateRoomAliases(ctx, event)
+ case "m.room.canonical_alias":
+ var content common.CanonicalAliasContent
+ field := &(content.Alias)
+ attrName := "canonical_alias"
+ return d.updateStringAttribute(ctx, attrName, event, &content, field)
+ case "m.room.name":
+ var content common.NameContent
+ field := &(content.Name)
+ attrName := "name"
+ return d.updateStringAttribute(ctx, attrName, event, &content, field)
+ case "m.room.topic":
+ var content common.TopicContent
+ field := &(content.Topic)
+ attrName := "topic"
+ return d.updateStringAttribute(ctx, attrName, event, &content, field)
+ case "m.room.avatar":
+ var content common.AvatarContent
+ field := &(content.URL)
+ attrName := "avatar_url"
+ return d.updateStringAttribute(ctx, attrName, event, &content, field)
+ case "m.room.history_visibility":
+ var content common.HistoryVisibilityContent
+ field := &(content.HistoryVisibility)
+ attrName := "world_readable"
+ strForTrue := "world_readable"
+ return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
+ case "m.room.guest_access":
+ var content common.GuestAccessContent
+ field := &(content.GuestAccess)
+ attrName := "guest_can_join"
+ strForTrue := "can_join"
+ return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue)
+ }
+
+ // If the event type didn't match, return with no error
+ return nil
+}
+
+// updateNumJoinedUsers updates the number of joined user in the database representation
+// of a room using a given "m.room.member" Matrix event.
+// If the membership property of the event isn't "join", ignores it and returs nil.
+// If the remove parameter is set to false, increments the joined members counter in the
+// database, if set to truem decrements it.
+// Returns an error if the update failed.
+func (d *PublicRoomsServerDatabase) updateNumJoinedUsers(
+ ctx context.Context, membershipEvent gomatrixserverlib.Event, remove bool,
+) error {
+ membership, err := membershipEvent.Membership()
+ if err != nil {
+ return err
+ }
+
+ if membership != gomatrixserverlib.Join {
+ return nil
+ }
+
+ if remove {
+ return d.statements.decrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
+ }
+ return d.statements.incrementJoinedMembersInRoom(ctx, membershipEvent.RoomID())
+}
+
+// updateStringAttribute updates a given string attribute in the database
+// representation of a room using a given string data field from content of the
+// Matrix event triggering the update.
+// Returns an error if decoding the Matrix event's content or updating the attribute
+// failed.
+func (d *PublicRoomsServerDatabase) updateStringAttribute(
+ ctx context.Context, attrName string, event gomatrixserverlib.Event,
+ content interface{}, field *string,
+) error {
+ if err := json.Unmarshal(event.Content(), content); err != nil {
+ return err
+ }
+
+ return d.statements.updateRoomAttribute(ctx, attrName, *field, event.RoomID())
+}
+
+// updateBooleanAttribute updates a given boolean attribute in the database
+// representation of a room using a given string data field from content of the
+// Matrix event triggering the update.
+// The attribute is set to true if the field matches a given string, false if not.
+// Returns an error if decoding the Matrix event's content or updating the attribute
+// failed.
+func (d *PublicRoomsServerDatabase) updateBooleanAttribute(
+ ctx context.Context, attrName string, event gomatrixserverlib.Event,
+ content interface{}, field *string, strForTrue string,
+) error {
+ if err := json.Unmarshal(event.Content(), content); err != nil {
+ return err
+ }
+
+ var attrValue bool
+ if *field == strForTrue {
+ attrValue = true
+ } else {
+ attrValue = false
+ }
+
+ return d.statements.updateRoomAttribute(ctx, attrName, attrValue, event.RoomID())
+}
+
+// updateRoomAliases decodes the content of a "m.room.aliases" Matrix event and update the list of aliases of
+// a given room with it.
+// Returns an error if decoding the Matrix event or updating the list failed.
+func (d *PublicRoomsServerDatabase) updateRoomAliases(
+ ctx context.Context, aliasesEvent gomatrixserverlib.Event,
+) error {
+ var content common.AliasesContent
+ if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil {
+ return err
+ }
+
+ return d.statements.updateRoomAttribute(
+ ctx, "aliases", content.Aliases, aliasesEvent.RoomID(),
+ )
+}
diff --git a/roomserver/input/events.go b/roomserver/input/events.go
index 03023a4a..a3b70753 100644
--- a/roomserver/input/events.go
+++ b/roomserver/input/events.go
@@ -196,7 +196,12 @@ func processInviteEvent(
return err
}
succeeded := false
- defer common.EndTransaction(updater, &succeeded)
+ defer func() {
+ txerr := common.EndTransaction(updater, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
if updater.IsJoin() {
// If the user is joined to the room then that takes precedence over this
diff --git a/roomserver/input/latest_events.go b/roomserver/input/latest_events.go
index 7e03d544..f9fd1d5d 100644
--- a/roomserver/input/latest_events.go
+++ b/roomserver/input/latest_events.go
@@ -60,7 +60,12 @@ func updateLatestEvents(
return
}
succeeded := false
- defer common.EndTransaction(updater, &succeeded)
+ defer func() {
+ txerr := common.EndTransaction(updater, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
u := latestEventsUpdater{
ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,
diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go
new file mode 100644
index 00000000..f6c83906
--- /dev/null
+++ b/roomserver/storage/sqlite3/event_json_table.go
@@ -0,0 +1,108 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const eventJSONSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_event_json (
+ event_nid INTEGER NOT NULL PRIMARY KEY,
+ event_json TEXT NOT NULL
+ );
+`
+
+const insertEventJSONSQL = `
+ INSERT INTO roomserver_event_json (event_nid, event_json) VALUES ($1, $2)
+ ON CONFLICT DO NOTHING
+`
+
+// Bulk event JSON lookup by numeric event ID.
+// Sort by the numeric event ID.
+// This means that we can use binary search to lookup by numeric event ID.
+const bulkSelectEventJSONSQL = `
+ SELECT event_nid, event_json FROM roomserver_event_json
+ WHERE event_nid IN ($1)
+ ORDER BY event_nid ASC
+`
+
+type eventJSONStatements struct {
+ db *sql.DB
+ insertEventJSONStmt *sql.Stmt
+ bulkSelectEventJSONStmt *sql.Stmt
+}
+
+func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ _, err = db.Exec(eventJSONSchema)
+ if err != nil {
+ return
+ }
+ return statementList{
+ {&s.insertEventJSONStmt, insertEventJSONSQL},
+ {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL},
+ }.prepare(db)
+}
+
+func (s *eventJSONStatements) insertEventJSON(
+ ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
+) error {
+ _, err := common.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
+ return err
+}
+
+type eventJSONPair struct {
+ EventNID types.EventNID
+ EventJSON []byte
+}
+
+func (s *eventJSONStatements) bulkSelectEventJSON(
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
+) ([]eventJSONPair, error) {
+ iEventNIDs := make([]interface{}, len(eventNIDs))
+ for k, v := range eventNIDs {
+ iEventNIDs[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
+
+ rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ // We know that we will only get as many results as event NIDs
+ // because of the unique constraint on event NIDs.
+ // So we can allocate an array of the correct size now.
+ // We might get fewer results than NIDs so we adjust the length of the slice before returning it.
+ results := make([]eventJSONPair, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ var eventNID int64
+ if err := rows.Scan(&eventNID, &result.EventJSON); err != nil {
+ return nil, err
+ }
+ result.EventNID = types.EventNID(eventNID)
+ }
+ return results[:i], nil
+}
diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go
new file mode 100644
index 00000000..b8bc6c02
--- /dev/null
+++ b/roomserver/storage/sqlite3/event_state_keys_table.go
@@ -0,0 +1,156 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const eventStateKeysSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_event_state_keys (
+ event_state_key_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ event_state_key TEXT NOT NULL UNIQUE
+ );
+ INSERT INTO roomserver_event_state_keys (event_state_key_nid, event_state_key)
+ VALUES (1, '')
+ ON CONFLICT DO NOTHING;
+`
+
+// Same as insertEventTypeNIDSQL
+const insertEventStateKeyNIDSQL = `
+ INSERT INTO roomserver_event_state_keys (event_state_key) VALUES ($1)
+ ON CONFLICT DO NOTHING;
+`
+
+const selectEventStateKeyNIDSQL = `
+ SELECT event_state_key_nid FROM roomserver_event_state_keys
+ WHERE event_state_key = $1
+`
+
+// Bulk lookup from string state key to numeric ID for that state key.
+// Takes an array of strings as the query parameter.
+const bulkSelectEventStateKeyNIDSQL = `
+ SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
+ WHERE event_state_key IN ($1)
+`
+
+// Bulk lookup from numeric ID to string state key for that state key.
+// Takes an array of strings as the query parameter.
+const bulkSelectEventStateKeySQL = `
+ SELECT event_state_key, event_state_key_nid FROM roomserver_event_state_keys
+ WHERE event_state_key_nid IN ($1)
+`
+
+type eventStateKeyStatements struct {
+ db *sql.DB
+ insertEventStateKeyNIDStmt *sql.Stmt
+ selectEventStateKeyNIDStmt *sql.Stmt
+ bulkSelectEventStateKeyNIDStmt *sql.Stmt
+ bulkSelectEventStateKeyStmt *sql.Stmt
+}
+
+func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ _, err = db.Exec(eventStateKeysSchema)
+ if err != nil {
+ return
+ }
+ return statementList{
+ {&s.insertEventStateKeyNIDStmt, insertEventStateKeyNIDSQL},
+ {&s.selectEventStateKeyNIDStmt, selectEventStateKeyNIDSQL},
+ {&s.bulkSelectEventStateKeyNIDStmt, bulkSelectEventStateKeyNIDSQL},
+ {&s.bulkSelectEventStateKeyStmt, bulkSelectEventStateKeySQL},
+ }.prepare(db)
+}
+
+func (s *eventStateKeyStatements) insertEventStateKeyNID(
+ ctx context.Context, txn *sql.Tx, eventStateKey string,
+) (types.EventStateKeyNID, error) {
+ var eventStateKeyNID int64
+ var err error
+ var res sql.Result
+ insertStmt := txn.Stmt(s.insertEventStateKeyNIDStmt)
+ if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil {
+ eventStateKeyNID, err = res.LastInsertId()
+ }
+ return types.EventStateKeyNID(eventStateKeyNID), err
+}
+
+func (s *eventStateKeyStatements) selectEventStateKeyNID(
+ ctx context.Context, txn *sql.Tx, eventStateKey string,
+) (types.EventStateKeyNID, error) {
+ var eventStateKeyNID int64
+ stmt := txn.Stmt(s.selectEventStateKeyNIDStmt)
+ err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
+ return types.EventStateKeyNID(eventStateKeyNID), err
+}
+
+func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
+ ctx context.Context, txn *sql.Tx, eventStateKeys []string,
+) (map[string]types.EventStateKeyNID, error) {
+ iEventStateKeys := make([]interface{}, len(eventStateKeys))
+ for k, v := range eventStateKeys {
+ iEventStateKeys[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeys)), 1)
+
+ rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ result := make(map[string]types.EventStateKeyNID, len(eventStateKeys))
+ for rows.Next() {
+ var stateKey string
+ var stateKeyNID int64
+ if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
+ return nil, err
+ }
+ result[stateKey] = types.EventStateKeyNID(stateKeyNID)
+ }
+ return result, nil
+}
+
+func (s *eventStateKeyStatements) bulkSelectEventStateKey(
+ ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
+) (map[types.EventStateKeyNID]string, error) {
+ iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
+ for k, v := range eventStateKeyNIDs {
+ iEventStateKeyNIDs[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", common.QueryVariadic(len(eventStateKeyNIDs)), 1)
+
+ rows, err := txn.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ result := make(map[types.EventStateKeyNID]string, len(eventStateKeyNIDs))
+ for rows.Next() {
+ var stateKey string
+ var stateKeyNID int64
+ if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
+ return nil, err
+ }
+ result[types.EventStateKeyNID(stateKeyNID)] = stateKey
+ }
+ return result, nil
+}
diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go
new file mode 100644
index 00000000..edc06d4c
--- /dev/null
+++ b/roomserver/storage/sqlite3/event_types_table.go
@@ -0,0 +1,153 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const eventTypesSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_event_types (
+ event_type_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ event_type TEXT NOT NULL UNIQUE
+ );
+ INSERT INTO roomserver_event_types (event_type_nid, event_type) VALUES
+ (1, 'm.room.create'),
+ (2, 'm.room.power_levels'),
+ (3, 'm.room.join_rules'),
+ (4, 'm.room.third_party_invite'),
+ (5, 'm.room.member'),
+ (6, 'm.room.redaction'),
+ (7, 'm.room.history_visibility') ON CONFLICT DO NOTHING;
+`
+
+// Assign a new numeric event type ID.
+// The usual case is that the event type is not in the database.
+// In that case the ID will be assigned using the next value from the sequence.
+// We use `RETURNING` to tell postgres to return the assigned ID.
+// But it's possible that the type was added in a query that raced with us.
+// This will result in a conflict on the event_type_unique constraint, in this
+// case we do nothing. Postgresql won't return a row in that case so we rely on
+// the caller catching the sql.ErrNoRows error and running a select to get the row.
+// We could get postgresql to return the row on a conflict by updating the row
+// but it doesn't seem like a good idea to modify the rows just to make postgresql
+// return it. Modifying the rows will cause postgres to assign a new tuple for the
+// row even though the data doesn't change resulting in unncesssary modifications
+// to the indexes.
+const insertEventTypeNIDSQL = `
+ INSERT INTO roomserver_event_types (event_type) VALUES ($1)
+ ON CONFLICT DO NOTHING;
+`
+
+const insertEventTypeNIDResultSQL = `
+ SELECT event_type_nid FROM roomserver_event_types
+ WHERE rowid = last_insert_rowid();
+`
+
+const selectEventTypeNIDSQL = `
+ SELECT event_type_nid FROM roomserver_event_types WHERE event_type = $1
+`
+
+// Bulk lookup from string event type to numeric ID for that event type.
+// Takes an array of strings as the query parameter.
+const bulkSelectEventTypeNIDSQL = `
+ SELECT event_type, event_type_nid FROM roomserver_event_types
+ WHERE event_type IN ($1)
+`
+
+type eventTypeStatements struct {
+ db *sql.DB
+ insertEventTypeNIDStmt *sql.Stmt
+ insertEventTypeNIDResultStmt *sql.Stmt
+ selectEventTypeNIDStmt *sql.Stmt
+ bulkSelectEventTypeNIDStmt *sql.Stmt
+}
+
+func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ _, err = db.Exec(eventTypesSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertEventTypeNIDStmt, insertEventTypeNIDSQL},
+ {&s.insertEventTypeNIDResultStmt, insertEventTypeNIDResultSQL},
+ {&s.selectEventTypeNIDStmt, selectEventTypeNIDSQL},
+ {&s.bulkSelectEventTypeNIDStmt, bulkSelectEventTypeNIDSQL},
+ }.prepare(db)
+}
+
+func (s *eventTypeStatements) insertEventTypeNID(
+ ctx context.Context, tx *sql.Tx, eventType string,
+) (types.EventTypeNID, error) {
+ var eventTypeNID int64
+ var err error
+ insertStmt := common.TxStmt(tx, s.insertEventTypeNIDStmt)
+ resultStmt := common.TxStmt(tx, s.insertEventTypeNIDResultStmt)
+ if _, err = insertStmt.ExecContext(ctx, eventType); err == nil {
+ err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
+ }
+ return types.EventTypeNID(eventTypeNID), err
+}
+
+func (s *eventTypeStatements) selectEventTypeNID(
+ ctx context.Context, tx *sql.Tx, eventType string,
+) (types.EventTypeNID, error) {
+ var eventTypeNID int64
+ selectStmt := common.TxStmt(tx, s.selectEventTypeNIDStmt)
+ err := selectStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
+ return types.EventTypeNID(eventTypeNID), err
+}
+
+func (s *eventTypeStatements) bulkSelectEventTypeNID(
+ ctx context.Context, tx *sql.Tx, eventTypes []string,
+) (map[string]types.EventTypeNID, error) {
+ ///////////////
+ iEventTypes := make([]interface{}, len(eventTypes))
+ for k, v := range eventTypes {
+ iEventTypes[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectEventTypeNIDSQL, "($1)", common.QueryVariadic(len(iEventTypes)), 1)
+ selectPrep, err := s.db.Prepare(selectOrig)
+ if err != nil {
+ return nil, err
+ }
+ ///////////////
+
+ selectStmt := common.TxStmt(tx, selectPrep)
+ rows, err := selectStmt.QueryContext(ctx, iEventTypes...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ result := make(map[string]types.EventTypeNID, len(eventTypes))
+ for rows.Next() {
+ var eventType string
+ var eventTypeNID int64
+ if err := rows.Scan(&eventType, &eventTypeNID); err != nil {
+ return nil, err
+ }
+ result[eventType] = types.EventTypeNID(eventTypeNID)
+ }
+ return result, nil
+}
diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go
new file mode 100644
index 00000000..4ed1395d
--- /dev/null
+++ b/roomserver/storage/sqlite3/events_table.go
@@ -0,0 +1,479 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const eventsSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_events (
+ event_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ room_nid INTEGER NOT NULL,
+ event_type_nid INTEGER NOT NULL,
+ event_state_key_nid INTEGER NOT NULL,
+ sent_to_output BOOLEAN NOT NULL DEFAULT FALSE,
+ state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
+ depth INTEGER NOT NULL,
+ event_id TEXT NOT NULL UNIQUE,
+ reference_sha256 BLOB NOT NULL,
+ auth_event_nids TEXT NOT NULL DEFAULT '{}'
+ );
+`
+
+const insertEventSQL = `
+ INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth)
+ VALUES ($1, $2, $3, $4, $5, $6, $7)
+ ON CONFLICT DO NOTHING;
+`
+
+const insertEventResultSQL = `
+ SELECT event_nid, state_snapshot_nid FROM roomserver_events
+ WHERE rowid = last_insert_rowid();
+`
+
+const selectEventSQL = "" +
+ "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
+
+// Bulk lookup of events by string ID.
+// Sort by the numeric IDs for event type and state key.
+// This means we can use binary search to lookup entries by type and state key.
+const bulkSelectStateEventByIDSQL = "" +
+ "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
+ " WHERE event_id IN ($1)" +
+ " ORDER BY event_type_nid, event_state_key_nid ASC"
+
+const bulkSelectStateAtEventByIDSQL = "" +
+ "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" +
+ " WHERE event_id IN ($1)"
+
+const updateEventStateSQL = "" +
+ "UPDATE roomserver_events SET state_snapshot_nid = $1 WHERE event_nid = $2"
+
+const selectEventSentToOutputSQL = "" +
+ "SELECT sent_to_output FROM roomserver_events WHERE event_nid = $1"
+
+const updateEventSentToOutputSQL = "" +
+ "UPDATE roomserver_events SET sent_to_output = TRUE WHERE event_nid = $1"
+
+const selectEventIDSQL = "" +
+ "SELECT event_id FROM roomserver_events WHERE event_nid = $1"
+
+const bulkSelectStateAtEventAndReferenceSQL = "" +
+ "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, event_id, reference_sha256" +
+ " FROM roomserver_events WHERE event_nid IN ($1)"
+
+const bulkSelectEventReferenceSQL = "" +
+ "SELECT event_id, reference_sha256 FROM roomserver_events WHERE event_nid IN ($1)"
+
+const bulkSelectEventIDSQL = "" +
+ "SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)"
+
+const bulkSelectEventNIDSQL = "" +
+ "SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)"
+
+const selectMaxEventDepthSQL = "" +
+ "SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
+
+type eventStatements struct {
+ db *sql.DB
+ insertEventStmt *sql.Stmt
+ insertEventResultStmt *sql.Stmt
+ selectEventStmt *sql.Stmt
+ bulkSelectStateEventByIDStmt *sql.Stmt
+ bulkSelectStateAtEventByIDStmt *sql.Stmt
+ updateEventStateStmt *sql.Stmt
+ selectEventSentToOutputStmt *sql.Stmt
+ updateEventSentToOutputStmt *sql.Stmt
+ selectEventIDStmt *sql.Stmt
+ bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
+ bulkSelectEventReferenceStmt *sql.Stmt
+ bulkSelectEventIDStmt *sql.Stmt
+ bulkSelectEventNIDStmt *sql.Stmt
+ selectMaxEventDepthStmt *sql.Stmt
+}
+
+func (s *eventStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ _, err = db.Exec(eventsSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertEventStmt, insertEventSQL},
+ {&s.insertEventResultStmt, insertEventResultSQL},
+ {&s.selectEventStmt, selectEventSQL},
+ {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
+ {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
+ {&s.updateEventStateStmt, updateEventStateSQL},
+ {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
+ {&s.selectEventSentToOutputStmt, selectEventSentToOutputSQL},
+ {&s.selectEventIDStmt, selectEventIDSQL},
+ {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL},
+ {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL},
+ {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
+ {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
+ {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
+ }.prepare(db)
+}
+
+func (s *eventStatements) insertEvent(
+ ctx context.Context,
+ txn *sql.Tx,
+ roomNID types.RoomNID,
+ eventTypeNID types.EventTypeNID,
+ eventStateKeyNID types.EventStateKeyNID,
+ eventID string,
+ referenceSHA256 []byte,
+ authEventNIDs []types.EventNID,
+ depth int64,
+) (types.EventNID, types.StateSnapshotNID, error) {
+ var eventNID int64
+ var stateNID int64
+ var err error
+ insertStmt := common.TxStmt(txn, s.insertEventStmt)
+ resultStmt := common.TxStmt(txn, s.insertEventResultStmt)
+ if _, err = insertStmt.ExecContext(
+ ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
+ eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
+ ); err == nil {
+ err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
+ }
+ return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
+}
+
+func (s *eventStatements) selectEvent(
+ ctx context.Context, txn *sql.Tx, eventID string,
+) (types.EventNID, types.StateSnapshotNID, error) {
+ var eventNID int64
+ var stateNID int64
+ selectStmt := common.TxStmt(txn, s.selectEventStmt)
+ err := selectStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID)
+ return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
+}
+
+// bulkSelectStateEventByID lookups a list of state events by event ID.
+// If any of the requested events are missing from the database it returns a types.MissingEventError
+func (s *eventStatements) bulkSelectStateEventByID(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) ([]types.StateEntry, error) {
+ ///////////////
+ iEventIDs := make([]interface{}, len(eventIDs))
+ for k, v := range eventIDs {
+ iEventIDs[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectStateEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
+ selectPrep, err := txn.Prepare(selectOrig)
+ if err != nil {
+ return nil, err
+ }
+ ///////////////
+
+ selectStmt := common.TxStmt(txn, selectPrep)
+ rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ // We know that we will only get as many results as event IDs
+ // because of the unique constraint on event IDs.
+ // So we can allocate an array of the correct size now.
+ // We might get fewer results than IDs so we adjust the length of the slice before returning it.
+ results := make([]types.StateEntry, len(eventIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ if err = rows.Scan(
+ &result.EventTypeNID,
+ &result.EventStateKeyNID,
+ &result.EventNID,
+ ); err != nil {
+ return nil, err
+ }
+ }
+ if i != len(eventIDs) {
+ // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
+ // We don't know which ones were missing because we don't return the string IDs in the query.
+ // However it should be possible debug this by replaying queries or entries from the input kafka logs.
+ // If this turns out to be impossible and we do need the debug information here, it would be better
+ // to do it as a separate query rather than slowing down/complicating the common case.
+ return nil, types.MissingEventError(
+ fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
+ )
+ }
+ return results, err
+}
+
+// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
+// If any of the requested events are missing from the database it returns a types.MissingEventError.
+// If we do not have the state for any of the requested events it returns a types.MissingEventError.
+func (s *eventStatements) bulkSelectStateAtEventByID(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) ([]types.StateAtEvent, error) {
+ ///////////////
+ iEventIDs := make([]interface{}, len(eventIDs))
+ for k, v := range eventIDs {
+ iEventIDs[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectStateAtEventByIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
+ selectPrep, err := txn.Prepare(selectOrig)
+ if err != nil {
+ return nil, err
+ }
+ ///////////////
+
+ selectStmt := common.TxStmt(txn, selectPrep)
+ rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make([]types.StateAtEvent, len(eventIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ if err = rows.Scan(
+ &result.EventTypeNID,
+ &result.EventStateKeyNID,
+ &result.EventNID,
+ &result.BeforeStateSnapshotNID,
+ ); err != nil {
+ return nil, err
+ }
+ if result.BeforeStateSnapshotNID == 0 {
+ return nil, types.MissingEventError(
+ fmt.Sprintf("storage: missing state for event NID %d", result.EventNID),
+ )
+ }
+ }
+ if i != len(eventIDs) {
+ return nil, types.MissingEventError(
+ fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
+ )
+ }
+ return results, err
+}
+
+func (s *eventStatements) updateEventState(
+ ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID,
+) error {
+ updateStmt := common.TxStmt(txn, s.updateEventStateStmt)
+ _, err := updateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
+ return err
+}
+
+func (s *eventStatements) selectEventSentToOutput(
+ ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
+) (sentToOutput bool, err error) {
+ selectStmt := common.TxStmt(txn, s.selectEventSentToOutputStmt)
+ err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
+ //err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
+ if err != nil {
+ }
+ return
+}
+
+func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
+ updateStmt := common.TxStmt(txn, s.updateEventSentToOutputStmt)
+ _, err := updateStmt.ExecContext(ctx, int64(eventNID))
+ //_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID))
+ return err
+}
+
+func (s *eventStatements) selectEventID(
+ ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
+) (eventID string, err error) {
+ selectStmt := common.TxStmt(txn, s.selectEventIDStmt)
+ err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID)
+ return
+}
+
+func (s *eventStatements) bulkSelectStateAtEventAndReference(
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
+) ([]types.StateAtEventAndReference, error) {
+ ///////////////
+ iEventNIDs := make([]interface{}, len(eventNIDs))
+ for k, v := range eventNIDs {
+ iEventNIDs[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectStateAtEventAndReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
+ //////////////
+
+ rows, err := txn.QueryContext(ctx, selectOrig, iEventNIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make([]types.StateAtEventAndReference, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ var (
+ eventTypeNID int64
+ eventStateKeyNID int64
+ eventNID int64
+ stateSnapshotNID int64
+ eventID string
+ eventSHA256 []byte
+ )
+ if err = rows.Scan(
+ &eventTypeNID, &eventStateKeyNID, &eventNID, &stateSnapshotNID, &eventID, &eventSHA256,
+ ); err != nil {
+ return nil, err
+ }
+ result := &results[i]
+ result.EventTypeNID = types.EventTypeNID(eventTypeNID)
+ result.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
+ result.EventNID = types.EventNID(eventNID)
+ result.BeforeStateSnapshotNID = types.StateSnapshotNID(stateSnapshotNID)
+ result.EventID = eventID
+ result.EventSHA256 = eventSHA256
+ }
+ if i != len(eventNIDs) {
+ return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
+ }
+ return results, nil
+}
+
+func (s *eventStatements) bulkSelectEventReference(
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
+) ([]gomatrixserverlib.EventReference, error) {
+ ///////////////
+ iEventNIDs := make([]interface{}, len(eventNIDs))
+ for k, v := range eventNIDs {
+ iEventNIDs[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectEventReferenceSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
+ selectPrep, err := txn.Prepare(selectOrig)
+ if err != nil {
+ return nil, err
+ }
+ ///////////////
+
+ selectStmt := common.TxStmt(txn, selectPrep)
+ rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make([]gomatrixserverlib.EventReference, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ if err = rows.Scan(&result.EventID, &result.EventSHA256); err != nil {
+ return nil, err
+ }
+ }
+ if i != len(eventNIDs) {
+ return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
+ }
+ return results, nil
+}
+
+// bulkSelectEventID returns a map from numeric event ID to string event ID.
+func (s *eventStatements) bulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
+ ///////////////
+ iEventNIDs := make([]interface{}, len(eventNIDs))
+ for k, v := range eventNIDs {
+ iEventNIDs[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectEventIDSQL, "($1)", common.QueryVariadic(len(iEventNIDs)), 1)
+ selectPrep, err := txn.Prepare(selectOrig)
+ if err != nil {
+ return nil, err
+ }
+ ///////////////
+
+ selectStmt := common.TxStmt(txn, selectPrep)
+ rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make(map[types.EventNID]string, len(eventNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ var eventNID int64
+ var eventID string
+ if err = rows.Scan(&eventNID, &eventID); err != nil {
+ return nil, err
+ }
+ results[types.EventNID(eventNID)] = eventID
+ }
+ if i != len(eventNIDs) {
+ return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
+ }
+ return results, nil
+}
+
+// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
+// If an event ID is not in the database then it is omitted from the map.
+func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
+ ///////////////
+ iEventIDs := make([]interface{}, len(eventIDs))
+ for k, v := range eventIDs {
+ iEventIDs[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectEventNIDSQL, "($1)", common.QueryVariadic(len(iEventIDs)), 1)
+ selectPrep, err := txn.Prepare(selectOrig)
+ if err != nil {
+ return nil, err
+ }
+ ///////////////
+
+ selectStmt := common.TxStmt(txn, selectPrep)
+ rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make(map[string]types.EventNID, len(eventIDs))
+ for rows.Next() {
+ var eventID string
+ var eventNID int64
+ if err = rows.Scan(&eventID, &eventNID); err != nil {
+ return nil, err
+ }
+ results[eventID] = types.EventNID(eventNID)
+ }
+ return results, nil
+}
+
+func (s *eventStatements) selectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) {
+ var result int64
+ selectStmt := common.TxStmt(txn, s.selectMaxEventDepthStmt)
+ err := selectStmt.QueryRowContext(ctx, sqliteIn(eventNIDsAsArray(eventNIDs))).Scan(&result)
+ if err != nil {
+ return 0, err
+ }
+ return result, nil
+}
+
+func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
+ nids := make([]int64, len(eventNIDs))
+ for i := range eventNIDs {
+ nids[i] = int64(eventNIDs[i])
+ }
+ return nids
+}
diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go
new file mode 100644
index 00000000..5a0f0bf7
--- /dev/null
+++ b/roomserver/storage/sqlite3/invite_table.go
@@ -0,0 +1,142 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const inviteSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_invites (
+ invite_event_id TEXT PRIMARY KEY,
+ room_nid INTEGER NOT NULL,
+ target_nid INTEGER NOT NULL,
+ sender_nid INTEGER NOT NULL DEFAULT 0,
+ retired BOOLEAN NOT NULL DEFAULT FALSE,
+ invite_event_json TEXT NOT NULL
+ );
+
+ CREATE INDEX IF NOT EXISTS roomserver_invites_active_idx ON roomserver_invites (target_nid, room_nid)
+ WHERE NOT retired;
+`
+const insertInviteEventSQL = "" +
+ "INSERT INTO roomserver_invites (invite_event_id, room_nid, target_nid," +
+ " sender_nid, invite_event_json) VALUES ($1, $2, $3, $4, $5)" +
+ " ON CONFLICT DO NOTHING"
+
+const selectInviteActiveForUserInRoomSQL = "" +
+ "SELECT sender_nid FROM roomserver_invites" +
+ " WHERE target_nid = $1 AND room_nid = $2" +
+ " AND NOT retired"
+
+// Retire every active invite for a user in a room.
+// Ideally we'd know which invite events were retired by a given update so we
+// wouldn't need to remove every active invite.
+// However the matrix protocol doesn't give us a way to reliably identify the
+// invites that were retired, so we are forced to retire all of them.
+const updateInviteRetiredSQL = `
+ UPDATE roomserver_invites SET retired = TRUE
+ WHERE room_nid = $1 AND target_nid = $2 AND NOT retired;
+ SELECT invite_event_id FROM roomserver_invites
+ WHERE rowid = last_insert_rowid();
+`
+
+type inviteStatements struct {
+ insertInviteEventStmt *sql.Stmt
+ selectInviteActiveForUserInRoomStmt *sql.Stmt
+ updateInviteRetiredStmt *sql.Stmt
+}
+
+func (s *inviteStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(inviteSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertInviteEventStmt, insertInviteEventSQL},
+ {&s.selectInviteActiveForUserInRoomStmt, selectInviteActiveForUserInRoomSQL},
+ {&s.updateInviteRetiredStmt, updateInviteRetiredSQL},
+ }.prepare(db)
+}
+
+func (s *inviteStatements) insertInviteEvent(
+ ctx context.Context,
+ txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
+ targetUserNID, senderUserNID types.EventStateKeyNID,
+ inviteEventJSON []byte,
+) (bool, error) {
+ stmt := common.TxStmt(txn, s.insertInviteEventStmt)
+ defer stmt.Close() // nolint: errcheck
+ result, err := stmt.ExecContext(
+ ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
+ )
+ if err != nil {
+ return false, err
+ }
+ count, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return count != 0, nil
+}
+
+func (s *inviteStatements) updateInviteRetired(
+ ctx context.Context,
+ txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) (eventIDs []string, err error) {
+ stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
+ rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
+ if err != nil {
+ return nil, err
+ }
+ defer (func() { err = rows.Close() })()
+ for rows.Next() {
+ var inviteEventID string
+ if err := rows.Scan(&inviteEventID); err != nil {
+ return nil, err
+ }
+ eventIDs = append(eventIDs, inviteEventID)
+ }
+ return
+}
+
+// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
+func (s *inviteStatements) selectInviteActiveForUserInRoom(
+ ctx context.Context,
+ targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
+) ([]types.EventStateKeyNID, error) {
+ rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
+ ctx, targetUserNID, roomNID,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ var result []types.EventStateKeyNID
+ for rows.Next() {
+ var senderUserNID int64
+ if err := rows.Scan(&senderUserNID); err != nil {
+ return nil, err
+ }
+ result = append(result, types.EventStateKeyNID(senderUserNID))
+ }
+ return result, nil
+}
diff --git a/roomserver/storage/sqlite3/list.go b/roomserver/storage/sqlite3/list.go
new file mode 100644
index 00000000..4fe4e334
--- /dev/null
+++ b/roomserver/storage/sqlite3/list.go
@@ -0,0 +1,18 @@
+package sqlite3
+
+import (
+ "strconv"
+ "strings"
+
+ "github.com/lib/pq"
+)
+
+type SqliteList string
+
+func sqliteIn(a pq.Int64Array) string {
+ var b []string
+ for _, n := range a {
+ b = append(b, strconv.FormatInt(n, 10))
+ }
+ return strings.Join(b, ",")
+}
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
new file mode 100644
index 00000000..97877673
--- /dev/null
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -0,0 +1,180 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+type membershipState int64
+
+const (
+ membershipStateLeaveOrBan membershipState = 1
+ membershipStateInvite membershipState = 2
+ membershipStateJoin membershipState = 3
+)
+
+const membershipSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_membership (
+ room_nid INTEGER NOT NULL,
+ target_nid INTEGER NOT NULL,
+ sender_nid INTEGER NOT NULL DEFAULT 0,
+ membership_nid INTEGER NOT NULL DEFAULT 1,
+ event_nid INTEGER NOT NULL DEFAULT 0,
+ UNIQUE (room_nid, target_nid)
+ );
+`
+
+// Insert a row in to membership table so that it can be locked by the
+// SELECT FOR UPDATE
+const insertMembershipSQL = "" +
+ "INSERT INTO roomserver_membership (room_nid, target_nid)" +
+ " VALUES ($1, $2)" +
+ " ON CONFLICT DO NOTHING"
+
+const selectMembershipFromRoomAndTargetSQL = "" +
+ "SELECT membership_nid, event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1 AND target_nid = $2"
+
+const selectMembershipsFromRoomAndMembershipSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1 AND membership_nid = $2"
+
+const selectMembershipsFromRoomSQL = "" +
+ "SELECT event_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1"
+
+const selectMembershipForUpdateSQL = "" +
+ "SELECT membership_nid FROM roomserver_membership" +
+ " WHERE room_nid = $1 AND target_nid = $2"
+
+const updateMembershipSQL = "" +
+ "UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" +
+ " WHERE room_nid = $4 AND target_nid = $5"
+
+type membershipStatements struct {
+ insertMembershipStmt *sql.Stmt
+ selectMembershipForUpdateStmt *sql.Stmt
+ selectMembershipFromRoomAndTargetStmt *sql.Stmt
+ selectMembershipsFromRoomAndMembershipStmt *sql.Stmt
+ selectMembershipsFromRoomStmt *sql.Stmt
+ updateMembershipStmt *sql.Stmt
+}
+
+func (s *membershipStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(membershipSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertMembershipStmt, insertMembershipSQL},
+ {&s.selectMembershipForUpdateStmt, selectMembershipForUpdateSQL},
+ {&s.selectMembershipFromRoomAndTargetStmt, selectMembershipFromRoomAndTargetSQL},
+ {&s.selectMembershipsFromRoomAndMembershipStmt, selectMembershipsFromRoomAndMembershipSQL},
+ {&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
+ {&s.updateMembershipStmt, updateMembershipSQL},
+ }.prepare(db)
+}
+
+func (s *membershipStatements) insertMembership(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) error {
+ stmt := common.TxStmt(txn, s.insertMembershipStmt)
+ _, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
+ return err
+}
+
+func (s *membershipStatements) selectMembershipForUpdate(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) (membership membershipState, err error) {
+ stmt := common.TxStmt(txn, s.selectMembershipForUpdateStmt)
+ err = stmt.QueryRowContext(
+ ctx, roomNID, targetUserNID,
+ ).Scan(&membership)
+ return
+}
+
+func (s *membershipStatements) selectMembershipFromRoomAndTarget(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+) (eventNID types.EventNID, membership membershipState, err error) {
+ selectStmt := common.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
+ err = selectStmt.QueryRowContext(
+ ctx, roomNID, targetUserNID,
+ ).Scan(&membership, &eventNID)
+ return
+}
+
+func (s *membershipStatements) selectMembershipsFromRoom(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID,
+) (eventNIDs []types.EventNID, err error) {
+ selectStmt := common.TxStmt(txn, s.selectMembershipsFromRoomStmt)
+ rows, err := selectStmt.QueryContext(ctx, roomNID)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ for rows.Next() {
+ var eNID types.EventNID
+ if err = rows.Scan(&eNID); err != nil {
+ return
+ }
+ eventNIDs = append(eventNIDs, eNID)
+ }
+ return
+}
+func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, membership membershipState,
+) (eventNIDs []types.EventNID, err error) {
+ stmt := common.TxStmt(txn, s.selectMembershipsFromRoomAndMembershipStmt)
+ rows, err := stmt.QueryContext(ctx, roomNID, membership)
+ if err != nil {
+ return
+ }
+ defer rows.Close() // nolint: errcheck
+
+ for rows.Next() {
+ var eNID types.EventNID
+ if err = rows.Scan(&eNID); err != nil {
+ return
+ }
+ eventNIDs = append(eventNIDs, eNID)
+ }
+ return
+}
+
+func (s *membershipStatements) updateMembership(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ senderUserNID types.EventStateKeyNID, membership membershipState,
+ eventNID types.EventNID,
+) error {
+ stmt := common.TxStmt(txn, s.updateMembershipStmt)
+ _, err := stmt.ExecContext(
+ ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
+ )
+ return err
+}
diff --git a/roomserver/storage/sqlite3/prepare.go b/roomserver/storage/sqlite3/prepare.go
new file mode 100644
index 00000000..482dfa2b
--- /dev/null
+++ b/roomserver/storage/sqlite3/prepare.go
@@ -0,0 +1,36 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "database/sql"
+)
+
+// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
+type statementList []struct {
+ statement **sql.Stmt
+ sql string
+}
+
+// prepare the SQL for each statement in the list and assign the result to the prepared statement.
+func (s statementList) prepare(db *sql.DB) (err error) {
+ for _, statement := range s {
+ if *statement.statement, err = db.Prepare(statement.sql); err != nil {
+ return
+ }
+ }
+ return
+}
diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go
new file mode 100644
index 00000000..9ed64a38
--- /dev/null
+++ b/roomserver/storage/sqlite3/previous_events_table.go
@@ -0,0 +1,92 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const previousEventSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_previous_events (
+ previous_event_id TEXT NOT NULL,
+ previous_reference_sha256 BLOB NOT NULL,
+ event_nids TEXT NOT NULL,
+ UNIQUE (previous_event_id, previous_reference_sha256)
+ );
+`
+
+// Insert an entry into the previous_events table.
+// If there is already an entry indicating that an event references that previous event then
+// add the event NID to the list to indicate that this event references that previous event as well.
+// This should only be modified while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
+// The lock is necessary to avoid data races when checking whether an event is already referenced by another event.
+const insertPreviousEventSQL = `
+ INSERT OR REPLACE INTO roomserver_previous_events
+ (previous_event_id, previous_reference_sha256, event_nids)
+ VALUES ($1, $2, $3)
+`
+
+// Check if the event is referenced by another event in the table.
+// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
+const selectPreviousEventExistsSQL = `
+ SELECT 1 FROM roomserver_previous_events
+ WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
+`
+
+type previousEventStatements struct {
+ insertPreviousEventStmt *sql.Stmt
+ selectPreviousEventExistsStmt *sql.Stmt
+}
+
+func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(previousEventSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertPreviousEventStmt, insertPreviousEventSQL},
+ {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
+ }.prepare(db)
+}
+
+func (s *previousEventStatements) insertPreviousEvent(
+ ctx context.Context,
+ txn *sql.Tx,
+ previousEventID string,
+ previousEventReferenceSHA256 []byte,
+ eventNID types.EventNID,
+) error {
+ stmt := common.TxStmt(txn, s.insertPreviousEventStmt)
+ _, err := stmt.ExecContext(
+ ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
+ )
+ return err
+}
+
+// Check if the event reference exists
+// Returns sql.ErrNoRows if the event reference doesn't exist.
+func (s *previousEventStatements) selectPreviousEventExists(
+ ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
+) error {
+ var ok int64
+ stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt)
+ return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
+}
diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go
new file mode 100644
index 00000000..a5fd5449
--- /dev/null
+++ b/roomserver/storage/sqlite3/room_aliases_table.go
@@ -0,0 +1,135 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+)
+
+const roomAliasesSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_room_aliases (
+ alias TEXT NOT NULL PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ creator_id TEXT NOT NULL
+ );
+
+ CREATE INDEX IF NOT EXISTS roomserver_room_id_idx ON roomserver_room_aliases(room_id);
+`
+
+const insertRoomAliasSQL = `
+ INSERT INTO roomserver_room_aliases (alias, room_id, creator_id) VALUES ($1, $2, $3)
+`
+
+const selectRoomIDFromAliasSQL = `
+ SELECT room_id FROM roomserver_room_aliases WHERE alias = $1
+`
+
+const selectAliasesFromRoomIDSQL = `
+ SELECT alias FROM roomserver_room_aliases WHERE room_id = $1
+`
+
+const selectCreatorIDFromAliasSQL = `
+ SELECT creator_id FROM roomserver_room_aliases WHERE alias = $1
+`
+
+const deleteRoomAliasSQL = `
+ DELETE FROM roomserver_room_aliases WHERE alias = $1
+`
+
+type roomAliasesStatements struct {
+ insertRoomAliasStmt *sql.Stmt
+ selectRoomIDFromAliasStmt *sql.Stmt
+ selectAliasesFromRoomIDStmt *sql.Stmt
+ selectCreatorIDFromAliasStmt *sql.Stmt
+ deleteRoomAliasStmt *sql.Stmt
+}
+
+func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(roomAliasesSchema)
+ if err != nil {
+ return
+ }
+ return statementList{
+ {&s.insertRoomAliasStmt, insertRoomAliasSQL},
+ {&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
+ {&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
+ {&s.selectCreatorIDFromAliasStmt, selectCreatorIDFromAliasSQL},
+ {&s.deleteRoomAliasStmt, deleteRoomAliasSQL},
+ }.prepare(db)
+}
+
+func (s *roomAliasesStatements) insertRoomAlias(
+ ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string,
+) (err error) {
+ insertStmt := common.TxStmt(txn, s.insertRoomAliasStmt)
+ _, err = insertStmt.ExecContext(ctx, alias, roomID, creatorUserID)
+ return
+}
+
+func (s *roomAliasesStatements) selectRoomIDFromAlias(
+ ctx context.Context, txn *sql.Tx, alias string,
+) (roomID string, err error) {
+ selectStmt := common.TxStmt(txn, s.selectRoomIDFromAliasStmt)
+ err = selectStmt.QueryRowContext(ctx, alias).Scan(&roomID)
+ if err == sql.ErrNoRows {
+ return "", nil
+ }
+ return
+}
+
+func (s *roomAliasesStatements) selectAliasesFromRoomID(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (aliases []string, err error) {
+ aliases = []string{}
+ selectStmt := common.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
+ rows, err := selectStmt.QueryContext(ctx, roomID)
+ if err != nil {
+ return
+ }
+
+ for rows.Next() {
+ var alias string
+ if err = rows.Scan(&alias); err != nil {
+ return
+ }
+
+ aliases = append(aliases, alias)
+ }
+
+ return
+}
+
+func (s *roomAliasesStatements) selectCreatorIDFromAlias(
+ ctx context.Context, txn *sql.Tx, alias string,
+) (creatorID string, err error) {
+ selectStmt := common.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
+ err = selectStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
+ if err == sql.ErrNoRows {
+ return "", nil
+ }
+ return
+}
+
+func (s *roomAliasesStatements) deleteRoomAlias(
+ ctx context.Context, txn *sql.Tx, alias string,
+) (err error) {
+ deleteStmt := common.TxStmt(txn, s.deleteRoomAliasStmt)
+ _, err = deleteStmt.ExecContext(ctx, alias)
+ return
+}
diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go
new file mode 100644
index 00000000..bf237728
--- /dev/null
+++ b/roomserver/storage/sqlite3/rooms_table.go
@@ -0,0 +1,165 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const roomsSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_rooms (
+ room_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ room_id TEXT NOT NULL UNIQUE,
+ latest_event_nids TEXT NOT NULL DEFAULT '{}',
+ last_event_sent_nid INTEGER NOT NULL DEFAULT 0,
+ state_snapshot_nid INTEGER NOT NULL DEFAULT 0,
+ room_version INTEGER NOT NULL DEFAULT 1
+ );
+`
+
+// Same as insertEventTypeNIDSQL
+const insertRoomNIDSQL = `
+ INSERT INTO roomserver_rooms (room_id) VALUES ($1)
+ ON CONFLICT DO NOTHING;
+`
+
+const selectRoomNIDSQL = "" +
+ "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1"
+
+const selectLatestEventNIDsSQL = "" +
+ "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
+
+const selectLatestEventNIDsForUpdateSQL = "" +
+ "SELECT latest_event_nids, last_event_sent_nid, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1"
+
+const updateLatestEventNIDsSQL = "" +
+ "UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
+
+const selectRoomVersionForRoomNIDSQL = "" +
+ "SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
+
+type roomStatements struct {
+ insertRoomNIDStmt *sql.Stmt
+ selectRoomNIDStmt *sql.Stmt
+ selectLatestEventNIDsStmt *sql.Stmt
+ selectLatestEventNIDsForUpdateStmt *sql.Stmt
+ updateLatestEventNIDsStmt *sql.Stmt
+ selectRoomVersionForRoomNIDStmt *sql.Stmt
+}
+
+func (s *roomStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(roomsSchema)
+ if err != nil {
+ return
+ }
+ return statementList{
+ {&s.insertRoomNIDStmt, insertRoomNIDSQL},
+ {&s.selectRoomNIDStmt, selectRoomNIDSQL},
+ {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
+ {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
+ {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
+ {&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
+ }.prepare(db)
+}
+
+func (s *roomStatements) insertRoomNID(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (types.RoomNID, error) {
+ var err error
+ insertStmt := common.TxStmt(txn, s.insertRoomNIDStmt)
+ if _, err = insertStmt.ExecContext(ctx, roomID); err == nil {
+ return s.selectRoomNID(ctx, txn, roomID)
+ } else {
+ return types.RoomNID(0), err
+ }
+}
+
+func (s *roomStatements) selectRoomNID(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (types.RoomNID, error) {
+ var roomNID int64
+ stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
+ err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
+ return types.RoomNID(roomNID), err
+}
+
+func (s *roomStatements) selectLatestEventNIDs(
+ ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
+) ([]types.EventNID, types.StateSnapshotNID, error) {
+ var nids pq.Int64Array
+ var stateSnapshotNID int64
+ stmt := common.TxStmt(txn, s.selectLatestEventNIDsStmt)
+ err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
+ if err != nil {
+ return nil, 0, err
+ }
+ eventNIDs := make([]types.EventNID, len(nids))
+ for i := range nids {
+ eventNIDs[i] = types.EventNID(nids[i])
+ }
+ return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
+}
+
+func (s *roomStatements) selectLatestEventsNIDsForUpdate(
+ ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
+) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
+ var nids pq.Int64Array
+ var lastEventSentNID int64
+ var stateSnapshotNID int64
+ stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
+ err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
+ if err != nil {
+ return nil, 0, 0, err
+ }
+ eventNIDs := make([]types.EventNID, len(nids))
+ for i := range nids {
+ eventNIDs[i] = types.EventNID(nids[i])
+ }
+ return eventNIDs, types.EventNID(lastEventSentNID), types.StateSnapshotNID(stateSnapshotNID), nil
+}
+
+func (s *roomStatements) updateLatestEventNIDs(
+ ctx context.Context,
+ txn *sql.Tx,
+ roomNID types.RoomNID,
+ eventNIDs []types.EventNID,
+ lastEventSentNID types.EventNID,
+ stateSnapshotNID types.StateSnapshotNID,
+) error {
+ stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ eventNIDsAsArray(eventNIDs),
+ int64(lastEventSentNID),
+ int64(stateSnapshotNID),
+ roomNID,
+ )
+ return err
+}
+
+func (s *roomStatements) selectRoomVersionForRoomNID(
+ ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
+) (int64, error) {
+ var roomVersion int64
+ stmt := common.TxStmt(txn, s.selectRoomVersionForRoomNIDStmt)
+ err := stmt.QueryRowContext(ctx, roomNID).Scan(&roomVersion)
+ return roomVersion, err
+}
diff --git a/roomserver/storage/sqlite3/sql.go b/roomserver/storage/sqlite3/sql.go
new file mode 100644
index 00000000..0d49432b
--- /dev/null
+++ b/roomserver/storage/sqlite3/sql.go
@@ -0,0 +1,60 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "database/sql"
+)
+
+type statements struct {
+ eventTypeStatements
+ eventStateKeyStatements
+ roomStatements
+ eventStatements
+ eventJSONStatements
+ stateSnapshotStatements
+ stateBlockStatements
+ previousEventStatements
+ roomAliasesStatements
+ inviteStatements
+ membershipStatements
+ transactionStatements
+}
+
+func (s *statements) prepare(db *sql.DB) error {
+ var err error
+
+ for _, prepare := range []func(db *sql.DB) error{
+ s.eventTypeStatements.prepare,
+ s.eventStateKeyStatements.prepare,
+ s.roomStatements.prepare,
+ s.eventStatements.prepare,
+ s.eventJSONStatements.prepare,
+ s.stateSnapshotStatements.prepare,
+ s.stateBlockStatements.prepare,
+ s.previousEventStatements.prepare,
+ s.roomAliasesStatements.prepare,
+ s.inviteStatements.prepare,
+ s.membershipStatements.prepare,
+ s.transactionStatements.prepare,
+ } {
+ if err = prepare(db); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go
new file mode 100644
index 00000000..ac593546
--- /dev/null
+++ b/roomserver/storage/sqlite3/state_block_table.go
@@ -0,0 +1,292 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "sort"
+ "strings"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/util"
+)
+
+const stateDataSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_state_block (
+ state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ event_type_nid INTEGER NOT NULL,
+ event_state_key_nid INTEGER NOT NULL,
+ event_nid INTEGER NOT NULL,
+ UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
+ );
+`
+
+const insertStateDataSQL = "" +
+ "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
+ " VALUES ($1, $2, $3, $4)"
+
+const selectNextStateBlockNIDSQL = `
+ SELECT COALESCE((
+ SELECT seq+1 AS state_block_nid FROM sqlite_sequence
+ WHERE name = 'roomserver_state_block'), 1
+ ) AS state_block_nid
+`
+
+// Bulk state lookup by numeric state block ID.
+// Sort by the state_block_nid, event_type_nid, event_state_key_nid
+// This means that all the entries for a given state_block_nid will appear
+// together in the list and those entries will sorted by event_type_nid
+// and event_state_key_nid. This property makes it easier to merge two
+// state data blocks together.
+const bulkSelectStateBlockEntriesSQL = "" +
+ "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
+ " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
+ " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
+
+// Bulk state lookup by numeric state block ID.
+// Filters the rows in each block to the requested types and state keys.
+// We would like to restrict to particular type state key pairs but we are
+// restricted by the query language to pull the cross product of a list
+// of types and a list state_keys. So we have to filter the result in the
+// application to restrict it to the list of event types and state keys we
+// actually wanted.
+const bulkSelectFilteredStateBlockEntriesSQL = "" +
+ "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
+ " FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
+ " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
+ " ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
+
+type stateBlockStatements struct {
+ db *sql.DB
+ insertStateDataStmt *sql.Stmt
+ selectNextStateBlockNIDStmt *sql.Stmt
+ bulkSelectStateBlockEntriesStmt *sql.Stmt
+ bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
+}
+
+func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ _, err = db.Exec(stateDataSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertStateDataStmt, insertStateDataSQL},
+ {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
+ {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
+ {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
+ }.prepare(db)
+}
+
+func (s *stateBlockStatements) bulkInsertStateData(
+ ctx context.Context, txn *sql.Tx,
+ stateBlockNID types.StateBlockNID,
+ entries []types.StateEntry,
+) error {
+ for _, entry := range entries {
+ _, err := common.TxStmt(txn, s.insertStateDataStmt).ExecContext(
+ ctx,
+ int64(stateBlockNID),
+ int64(entry.EventTypeNID),
+ int64(entry.EventStateKeyNID),
+ int64(entry.EventNID),
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *stateBlockStatements) selectNextStateBlockNID(
+ ctx context.Context,
+ txn *sql.Tx,
+) (types.StateBlockNID, error) {
+ var stateBlockNID int64
+ selectStmt := common.TxStmt(txn, s.selectNextStateBlockNIDStmt)
+ err := selectStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
+ return types.StateBlockNID(stateBlockNID), err
+}
+
+func (s *stateBlockStatements) bulkSelectStateBlockEntries(
+ ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
+) ([]types.StateEntryList, error) {
+ nids := make([]interface{}, len(stateBlockNIDs))
+ for k, v := range stateBlockNIDs {
+ nids[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(nids)), 1)
+ selectPrep, err := s.db.Prepare(selectOrig)
+ if err != nil {
+ return nil, err
+ }
+ selectStmt := common.TxStmt(txn, selectPrep)
+ rows, err := selectStmt.QueryContext(ctx, nids...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ results := make([]types.StateEntryList, len(stateBlockNIDs))
+ // current is a pointer to the StateEntryList to append the state entries to.
+ var current *types.StateEntryList
+ i := 0
+ for rows.Next() {
+ var (
+ stateBlockNID int64
+ eventTypeNID int64
+ eventStateKeyNID int64
+ eventNID int64
+ entry types.StateEntry
+ )
+ if err := rows.Scan(
+ &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
+ ); err != nil {
+ return nil, err
+ }
+ entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
+ entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
+ entry.EventNID = types.EventNID(eventNID)
+ if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
+ // The state entry row is for a different state data block to the current one.
+ // So we start appending to the next entry in the list.
+ current = &results[i]
+ current.StateBlockNID = types.StateBlockNID(stateBlockNID)
+ i++
+ }
+ current.StateEntries = append(current.StateEntries, entry)
+ }
+ if i != len(nids) {
+ return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
+ }
+ return results, nil
+}
+
+func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
+ ctx context.Context, txn *sql.Tx, // nolint: unparam
+ stateBlockNIDs []types.StateBlockNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntryList, error) {
+ tuples := stateKeyTupleSorter(stateKeyTuples)
+ // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
+ sort.Sort(tuples)
+
+ eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
+ sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", common.QueryVariadic(len(stateBlockNIDs)), 1)
+ sqlStatement = strings.Replace(sqlStatement, "($2)", common.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
+ sqlStatement = strings.Replace(sqlStatement, "($3)", common.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
+
+ var params []interface{}
+ for _, val := range stateBlockNIDs {
+ params = append(params, int64(val))
+ }
+ for _, val := range eventTypeNIDArray {
+ params = append(params, val)
+ }
+ for _, val := range eventStateKeyNIDArray {
+ params = append(params, val)
+ }
+
+ rows, err := s.db.QueryContext(
+ ctx,
+ sqlStatement,
+ params...,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ var results []types.StateEntryList
+ var current types.StateEntryList
+ for rows.Next() {
+ var (
+ stateBlockNID int64
+ eventTypeNID int64
+ eventStateKeyNID int64
+ eventNID int64
+ entry types.StateEntry
+ )
+ if err := rows.Scan(
+ &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
+ ); err != nil {
+ return nil, err
+ }
+ entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
+ entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
+ entry.EventNID = types.EventNID(eventNID)
+
+ // We can use binary search here because we sorted the tuples earlier
+ if !tuples.contains(entry.StateKeyTuple) {
+ // The select will return the cross product of types and state keys.
+ // So we need to check if type of the entry is in the list.
+ continue
+ }
+
+ if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
+ // The state entry row is for a different state data block to the current one.
+ // So we append the current entry to the results and start adding to a new one.
+ // The first time through the loop current will be empty.
+ if current.StateEntries != nil {
+ results = append(results, current)
+ }
+ current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
+ }
+ current.StateEntries = append(current.StateEntries, entry)
+ }
+ // Add the last entry to the list if it is not empty.
+ if current.StateEntries != nil {
+ results = append(results, current)
+ }
+ return results, nil
+}
+
+type stateKeyTupleSorter []types.StateKeyTuple
+
+func (s stateKeyTupleSorter) Len() int { return len(s) }
+func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
+func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// Check whether a tuple is in the list. Assumes that the list is sorted.
+func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
+ i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
+ return i < len(s) && s[i] == value
+}
+
+// List the unique eventTypeNIDs and eventStateKeyNIDs.
+// Assumes that the list is sorted.
+func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
+ eventTypeNIDs = make(pq.Int64Array, len(s))
+ eventStateKeyNIDs = make(pq.Int64Array, len(s))
+ for i := range s {
+ eventTypeNIDs[i] = int64(s[i].EventTypeNID)
+ eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
+ }
+ eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
+ eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
+ return
+}
+
+type int64Sorter []int64
+
+func (s int64Sorter) Len() int { return len(s) }
+func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
+func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
diff --git a/roomserver/storage/sqlite3/state_block_table_test.go b/roomserver/storage/sqlite3/state_block_table_test.go
new file mode 100644
index 00000000..98439f5c
--- /dev/null
+++ b/roomserver/storage/sqlite3/state_block_table_test.go
@@ -0,0 +1,86 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "sort"
+ "testing"
+
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+func TestStateKeyTupleSorter(t *testing.T) {
+ input := stateKeyTupleSorter{
+ {EventTypeNID: 1, EventStateKeyNID: 2},
+ {EventTypeNID: 1, EventStateKeyNID: 4},
+ {EventTypeNID: 2, EventStateKeyNID: 2},
+ {EventTypeNID: 1, EventStateKeyNID: 1},
+ }
+ want := []types.StateKeyTuple{
+ {EventTypeNID: 1, EventStateKeyNID: 1},
+ {EventTypeNID: 1, EventStateKeyNID: 2},
+ {EventTypeNID: 1, EventStateKeyNID: 4},
+ {EventTypeNID: 2, EventStateKeyNID: 2},
+ }
+ doNotWant := []types.StateKeyTuple{
+ {EventTypeNID: 0, EventStateKeyNID: 0},
+ {EventTypeNID: 1, EventStateKeyNID: 3},
+ {EventTypeNID: 2, EventStateKeyNID: 1},
+ {EventTypeNID: 3, EventStateKeyNID: 1},
+ }
+ wantTypeNIDs := []int64{1, 2}
+ wantStateKeyNIDs := []int64{1, 2, 4}
+
+ // Sort the input and check it's in the right order.
+ sort.Sort(input)
+ gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
+
+ for i := range want {
+ if input[i] != want[i] {
+ t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
+ }
+
+ if !input.contains(want[i]) {
+ t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
+ }
+ }
+
+ for i := range doNotWant {
+ if input.contains(doNotWant[i]) {
+ t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
+ }
+ }
+
+ if len(wantTypeNIDs) != len(gotTypeNIDs) {
+ t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
+ }
+
+ for i := range wantTypeNIDs {
+ if wantTypeNIDs[i] != gotTypeNIDs[i] {
+ t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
+ }
+ }
+
+ if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
+ t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
+ }
+
+ for i := range wantStateKeyNIDs {
+ if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
+ t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
+ }
+ }
+}
diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go
new file mode 100644
index 00000000..df97aa41
--- /dev/null
+++ b/roomserver/storage/sqlite3/state_snapshot_table.go
@@ -0,0 +1,120 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "strings"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/types"
+)
+
+const stateSnapshotSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
+ state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
+ room_nid INTEGER NOT NULL,
+ state_block_nids TEXT NOT NULL DEFAULT '{}'
+ );
+`
+
+const insertStateSQL = `
+ INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
+ VALUES ($1, $2);`
+
+// Bulk state data NID lookup.
+// Sorting by state_snapshot_nid means we can use binary search over the result
+// to lookup the state data NIDs for a state snapshot NID.
+const bulkSelectStateBlockNIDsSQL = "" +
+ "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
+ " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
+
+type stateSnapshotStatements struct {
+ db *sql.DB
+ insertStateStmt *sql.Stmt
+ bulkSelectStateBlockNIDsStmt *sql.Stmt
+}
+
+func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ _, err = db.Exec(stateSnapshotSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertStateStmt, insertStateSQL},
+ {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
+ }.prepare(db)
+}
+
+func (s *stateSnapshotStatements) insertState(
+ ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
+) (stateNID types.StateSnapshotNID, err error) {
+ nids := make([]int64, len(stateBlockNIDs))
+ for i := range stateBlockNIDs {
+ nids[i] = int64(stateBlockNIDs[i])
+ }
+ insertStmt := txn.Stmt(s.insertStateStmt)
+ if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), pq.Int64Array(nids)); err2 == nil {
+ lastRowID, err3 := res.LastInsertId()
+ if err3 != nil {
+ err = err3
+ }
+ stateNID = types.StateSnapshotNID(lastRowID)
+ }
+ return
+}
+
+func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
+ ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
+) ([]types.StateBlockNIDList, error) {
+ nids := make([]interface{}, len(stateNIDs))
+ for k, v := range stateNIDs {
+ nids[k] = v
+ }
+ selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", common.QueryVariadic(len(nids)), 1)
+ selectStmt, err := txn.Prepare(selectOrig)
+ if err != nil {
+ return nil, err
+ }
+
+ rows, err := selectStmt.QueryContext(ctx, nids...)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ results := make([]types.StateBlockNIDList, len(stateNIDs))
+ i := 0
+ for ; rows.Next(); i++ {
+ result := &results[i]
+ var stateBlockNIDs pq.Int64Array
+ if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
+ return nil, err
+ }
+ result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
+ for k := range stateBlockNIDs {
+ result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
+ }
+ }
+ if i != len(stateNIDs) {
+ return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
+ }
+ return results, nil
+}
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
new file mode 100644
index 00000000..e20e8aed
--- /dev/null
+++ b/roomserver/storage/sqlite3/storage.go
@@ -0,0 +1,864 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "net/url"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ _ "github.com/mattn/go-sqlite3"
+)
+
+// A Database is used to store room events and stream offsets.
+type Database struct {
+ statements statements
+ db *sql.DB
+}
+
+// Open a postgres database.
+func Open(dataSourceName string) (*Database, error) {
+ var d Database
+ uri, err := url.Parse(dataSourceName)
+ if err != nil {
+ return nil, err
+ }
+ var cs string
+ if uri.Opaque != "" { // file:filename.db
+ cs = uri.Opaque
+ } else if uri.Path != "" { // file:///path/to/filename.db
+ cs = uri.Path
+ } else {
+ return nil, errors.New("no filename or path in connect string")
+ }
+ if d.db, err = sql.Open("sqlite3", cs); err != nil {
+ return nil, err
+ }
+ //d.db.Exec("PRAGMA journal_mode=WAL;")
+ //d.db.Exec("PRAGMA read_uncommitted = true;")
+ d.db.SetMaxOpenConns(2)
+ if err = d.statements.prepare(d.db); err != nil {
+ return nil, err
+ }
+ return &d, nil
+}
+
+// StoreEvent implements input.EventDatabase
+func (d *Database) StoreEvent(
+ ctx context.Context, event gomatrixserverlib.Event,
+ txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
+) (types.RoomNID, types.StateAtEvent, error) {
+ var (
+ roomNID types.RoomNID
+ eventTypeNID types.EventTypeNID
+ eventStateKeyNID types.EventStateKeyNID
+ eventNID types.EventNID
+ stateNID types.StateSnapshotNID
+ err error
+ )
+
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if txnAndSessionID != nil {
+ if err = d.statements.insertTransaction(
+ ctx, txn, txnAndSessionID.TransactionID,
+ txnAndSessionID.SessionID, event.Sender(), event.EventID(),
+ ); err != nil {
+ return err
+ }
+ }
+
+ if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID()); err != nil {
+ return err
+ }
+
+ if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
+ return err
+ }
+
+ eventStateKey := event.StateKey()
+ // Assigned a numeric ID for the state_key if there is one present.
+ // Otherwise set the numeric ID for the state_key to 0.
+ if eventStateKey != nil {
+ if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
+ return err
+ }
+ }
+
+ if eventNID, stateNID, err = d.statements.insertEvent(
+ ctx,
+ txn,
+ roomNID,
+ eventTypeNID,
+ eventStateKeyNID,
+ event.EventID(),
+ event.EventReference().EventSHA256,
+ authEventNIDs,
+ event.Depth(),
+ ); err != nil {
+ if err == sql.ErrNoRows {
+ // We've already inserted the event so select the numeric event ID
+ eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID())
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
+ return err
+ }
+
+ return nil
+ })
+ if err != nil {
+ return 0, types.StateAtEvent{}, err
+ }
+
+ return roomNID, types.StateAtEvent{
+ BeforeStateSnapshotNID: stateNID,
+ StateEntry: types.StateEntry{
+ StateKeyTuple: types.StateKeyTuple{
+ EventTypeNID: eventTypeNID,
+ EventStateKeyNID: eventStateKeyNID,
+ },
+ EventNID: eventNID,
+ },
+ }, nil
+}
+
+func (d *Database) assignRoomNID(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (roomNID types.RoomNID, err error) {
+ // Check if we already have a numeric ID in the database.
+ roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
+ if err == sql.ErrNoRows {
+ // We don't have a numeric ID so insert one into the database.
+ roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
+ if err == nil {
+ // Now get the numeric ID back out of the database
+ roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
+ }
+ }
+ return
+}
+
+func (d *Database) assignEventTypeNID(
+ ctx context.Context, txn *sql.Tx, eventType string,
+) (eventTypeNID types.EventTypeNID, err error) {
+ // Check if we already have a numeric ID in the database.
+ eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType)
+ if err == sql.ErrNoRows {
+ // We don't have a numeric ID so insert one into the database.
+ eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType)
+ if err == sql.ErrNoRows {
+ // We raced with another insert so run the select again.
+ eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType)
+ }
+ }
+ return
+}
+
+func (d *Database) assignStateKeyNID(
+ ctx context.Context, txn *sql.Tx, eventStateKey string,
+) (eventStateKeyNID types.EventStateKeyNID, err error) {
+ // Check if we already have a numeric ID in the database.
+ eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
+ if err == sql.ErrNoRows {
+ // We don't have a numeric ID so insert one into the database.
+ eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey)
+ if err == sql.ErrNoRows {
+ // We raced with another insert so run the select again.
+ eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
+ }
+ }
+ return
+}
+
+// StateEntriesForEventIDs implements input.EventDatabase
+func (d *Database) StateEntriesForEventIDs(
+ ctx context.Context, eventIDs []string,
+) (se []types.StateEntry, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs)
+ return err
+ })
+ return
+}
+
+// EventTypeNIDs implements state.RoomStateDatabase
+func (d *Database) EventTypeNIDs(
+ ctx context.Context, eventTypes []string,
+) (etnids map[string]types.EventTypeNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes)
+ return err
+ })
+ return
+}
+
+// EventStateKeyNIDs implements state.RoomStateDatabase
+func (d *Database) EventStateKeyNIDs(
+ ctx context.Context, eventStateKeys []string,
+) (esknids map[string]types.EventStateKeyNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
+ return err
+ })
+ return
+}
+
+// EventStateKeys implements query.RoomserverQueryAPIDatabase
+func (d *Database) EventStateKeys(
+ ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
+) (out map[types.EventStateKeyNID]string, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs)
+ return err
+ })
+ return
+}
+
+// EventNIDs implements query.RoomserverQueryAPIDatabase
+func (d *Database) EventNIDs(
+ ctx context.Context, eventIDs []string,
+) (out map[string]types.EventNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs)
+ return err
+ })
+ return
+}
+
+// Events implements input.EventDatabase
+func (d *Database) Events(
+ ctx context.Context, eventNIDs []types.EventNID,
+) ([]types.Event, error) {
+ var eventJSONs []eventJSONPair
+ var err error
+ results := make([]types.Event, len(eventNIDs))
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
+ if err != nil || len(eventJSONs) == 0 {
+ return nil
+ }
+ for i, eventJSON := range eventJSONs {
+ result := &results[i]
+ result.EventNID = eventJSON.EventNID
+ // TODO: Use NewEventFromTrustedJSON for efficiency
+ result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON)
+ if err != nil {
+ return nil
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ return []types.Event{}, err
+ }
+ return results, nil
+}
+
+// AddState implements input.EventDatabase
+func (d *Database) AddState(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ stateBlockNIDs []types.StateBlockNID,
+ state []types.StateEntry,
+) (stateNID types.StateSnapshotNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if len(state) > 0 {
+ var stateBlockNID types.StateBlockNID
+ stateBlockNID, err = d.statements.selectNextStateBlockNID(ctx, txn)
+ if err != nil {
+ return err
+ }
+ if err = d.statements.bulkInsertStateData(ctx, txn, stateBlockNID, state); err != nil {
+ return err
+ }
+ stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
+ }
+ stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs)
+ return err
+ })
+ if err != nil {
+ return 0, err
+ }
+ return
+}
+
+// SetState implements input.EventDatabase
+func (d *Database) SetState(
+ ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
+) error {
+ e := common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.statements.updateEventState(ctx, txn, eventNID, stateNID)
+ })
+ return e
+}
+
+// StateAtEventIDs implements input.EventDatabase
+func (d *Database) StateAtEventIDs(
+ ctx context.Context, eventIDs []string,
+) (se []types.StateAtEvent, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs)
+ return err
+ })
+ return
+}
+
+// StateBlockNIDs implements state.RoomStateDatabase
+func (d *Database) StateBlockNIDs(
+ ctx context.Context, stateNIDs []types.StateSnapshotNID,
+) (sl []types.StateBlockNIDList, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
+ return err
+ })
+ return
+}
+
+// StateEntries implements state.RoomStateDatabase
+func (d *Database) StateEntries(
+ ctx context.Context, stateBlockNIDs []types.StateBlockNID,
+) (sel []types.StateEntryList, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
+ return err
+ })
+ return
+}
+
+// SnapshotNIDFromEventID implements state.RoomStateDatabase
+func (d *Database) SnapshotNIDFromEventID(
+ ctx context.Context, eventID string,
+) (stateNID types.StateSnapshotNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ _, stateNID, err = d.statements.selectEvent(ctx, txn, eventID)
+ return err
+ })
+ return
+}
+
+// EventIDs implements input.RoomEventDatabase
+func (d *Database) EventIDs(
+ ctx context.Context, eventNIDs []types.EventNID,
+) (out map[types.EventNID]string, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs)
+ return err
+ })
+ return
+}
+
+// GetLatestEventsForUpdate implements input.EventDatabase
+func (d *Database) GetLatestEventsForUpdate(
+ ctx context.Context, roomNID types.RoomNID,
+) (types.RoomRecentEventsUpdater, error) {
+ txn, err := d.db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
+ d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
+ if err != nil {
+ txn.Rollback() // nolint: errcheck
+ return nil, err
+ }
+ stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
+ if err != nil {
+ txn.Rollback() // nolint: errcheck
+ return nil, err
+ }
+ var lastEventIDSent string
+ if lastEventNIDSent != 0 {
+ lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
+ if err != nil {
+ txn.Rollback() // nolint: errcheck
+ return nil, err
+ }
+ }
+
+ // FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get
+ // 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here)
+ // so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so
+ // we fail fast if someone tries to use the underlying txn object.
+ err = txn.Commit()
+ if err != nil {
+ return nil, err
+ }
+ return &roomRecentEventsUpdater{
+ transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
+ }, nil
+}
+
+// GetTransactionEventID implements input.EventDatabase
+func (d *Database) GetTransactionEventID(
+ ctx context.Context, transactionID string,
+ sessionID int64, userID string,
+) (string, error) {
+ eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID)
+ if err == sql.ErrNoRows {
+ return "", nil
+ }
+ return eventID, err
+}
+
+type roomRecentEventsUpdater struct {
+ transaction
+ d *Database
+ roomNID types.RoomNID
+ latestEvents []types.StateAtEventAndReference
+ lastEventIDSent string
+ currentStateSnapshotNID types.StateSnapshotNID
+}
+
+// LatestEvents implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
+ return u.latestEvents
+}
+
+// LastEventIDSent implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) LastEventIDSent() string {
+ return u.lastEventIDSent
+}
+
+// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
+ return u.currentStateSnapshotNID
+}
+
+// StorePreviousEvents implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
+ err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ for _, ref := range previousEventReferences {
+ if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ return err
+}
+
+// IsReferenced implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256)
+ if err == nil {
+ res = true
+ err = nil
+ }
+ if err == sql.ErrNoRows {
+ res = false
+ err = nil
+ }
+ return err
+ })
+ return
+}
+
+// SetLatestEvents implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) SetLatestEvents(
+ roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
+ currentStateSnapshotNID types.StateSnapshotNID,
+) error {
+ err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ eventNIDs := make([]types.EventNID, len(latest))
+ for i := range latest {
+ eventNIDs[i] = latest[i].EventNID
+ }
+ return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
+ })
+ return err
+}
+
+// HasEventBeenSent implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID)
+ return err
+ })
+ return
+}
+
+// MarkEventAsSent implements types.RoomRecentEventsUpdater
+func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
+ err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID)
+ })
+ return err
+}
+
+func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID)
+ return err
+ })
+ return
+}
+
+// RoomNID implements query.RoomserverQueryAPIDB
+func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
+ if err == sql.ErrNoRows {
+ roomNID = 0
+ err = nil
+ }
+ return err
+ })
+ return
+}
+
+// LatestEventIDs implements query.RoomserverQueryAPIDatabase
+func (d *Database) LatestEventIDs(
+ ctx context.Context, roomNID types.RoomNID,
+) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var eventNIDs []types.EventNID
+ eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID)
+ if err != nil {
+ return err
+ }
+ references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs)
+ if err != nil {
+ return err
+ }
+ depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs)
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ return
+}
+
+// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
+func (d *Database) GetInvitesForUser(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ targetUserNID types.EventStateKeyNID,
+) (senderUserIDs []types.EventStateKeyNID, err error) {
+ return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
+}
+
+// SetRoomAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
+ return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID)
+}
+
+// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
+ return d.statements.selectRoomIDFromAlias(ctx, nil, alias)
+}
+
+// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB
+func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
+ return d.statements.selectAliasesFromRoomID(ctx, nil, roomID)
+}
+
+// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) GetCreatorIDForAlias(
+ ctx context.Context, alias string,
+) (string, error) {
+ return d.statements.selectCreatorIDFromAlias(ctx, nil, alias)
+}
+
+// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
+func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
+ return d.statements.deleteRoomAlias(ctx, nil, alias)
+}
+
+// StateEntriesForTuples implements state.RoomStateDatabase
+func (d *Database) StateEntriesForTuples(
+ ctx context.Context,
+ stateBlockNIDs []types.StateBlockNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntryList, error) {
+ return d.statements.bulkSelectFilteredStateBlockEntries(
+ ctx, nil, stateBlockNIDs, stateKeyTuples,
+ )
+}
+
+// MembershipUpdater implements input.RoomEventDatabase
+func (d *Database) MembershipUpdater(
+ ctx context.Context, roomID, targetUserID string,
+) (types.MembershipUpdater, error) {
+ txn, err := d.db.Begin()
+ if err != nil {
+ return nil, err
+ }
+ succeeded := false
+ defer func() {
+ if !succeeded {
+ txn.Rollback() // nolint: errcheck
+ }
+ }()
+
+ roomNID, err := d.assignRoomNID(ctx, txn, roomID)
+ if err != nil {
+ return nil, err
+ }
+
+ targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID)
+ if err != nil {
+ return nil, err
+ }
+
+ updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
+ if err != nil {
+ return nil, err
+ }
+
+ succeeded = true
+ return updater, nil
+}
+
+type membershipUpdater struct {
+ transaction
+ d *Database
+ roomNID types.RoomNID
+ targetUserNID types.EventStateKeyNID
+ membership membershipState
+}
+
+func (d *Database) membershipUpdaterTxn(
+ ctx context.Context,
+ txn *sql.Tx,
+ roomNID types.RoomNID,
+ targetUserNID types.EventStateKeyNID,
+) (types.MembershipUpdater, error) {
+
+ if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
+ return nil, err
+ }
+
+ membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
+ if err != nil {
+ return nil, err
+ }
+
+ return &membershipUpdater{
+ transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
+ }, nil
+}
+
+// IsInvite implements types.MembershipUpdater
+func (u *membershipUpdater) IsInvite() bool {
+ return u.membership == membershipStateInvite
+}
+
+// IsJoin implements types.MembershipUpdater
+func (u *membershipUpdater) IsJoin() bool {
+ return u.membership == membershipStateJoin
+}
+
+// IsLeave implements types.MembershipUpdater
+func (u *membershipUpdater) IsLeave() bool {
+ return u.membership == membershipStateLeaveOrBan
+}
+
+// SetToInvite implements types.MembershipUpdater
+func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender())
+ if err != nil {
+ return err
+ }
+ inserted, err = u.d.statements.insertInviteEvent(
+ u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
+ )
+ if err != nil {
+ return err
+ }
+ if u.membership != membershipStateInvite {
+ if err = u.d.statements.updateMembership(
+ u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
+ ); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ return
+}
+
+// SetToJoin implements types.MembershipUpdater
+func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
+ if err != nil {
+ return err
+ }
+
+ // If this is a join event update, there is no invite to update
+ if !isUpdate {
+ inviteEventIDs, err = u.d.statements.updateInviteRetired(
+ u.ctx, txn, u.roomNID, u.targetUserNID,
+ )
+ if err != nil {
+ return err
+ }
+ }
+
+ // Look up the NID of the new join event
+ nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
+ if err != nil {
+ return err
+ }
+
+ if u.membership != membershipStateJoin || isUpdate {
+ if err = u.d.statements.updateMembership(
+ u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
+ membershipStateJoin, nIDs[eventID],
+ ); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+
+ return
+}
+
+// SetToLeave implements types.MembershipUpdater
+func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) {
+ err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
+ senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
+ if err != nil {
+ return err
+ }
+ inviteEventIDs, err = u.d.statements.updateInviteRetired(
+ u.ctx, txn, u.roomNID, u.targetUserNID,
+ )
+ if err != nil {
+ return err
+ }
+
+ // Look up the NID of the new leave event
+ nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
+ if err != nil {
+ return err
+ }
+
+ if u.membership != membershipStateLeaveOrBan {
+ if err = u.d.statements.updateMembership(
+ u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
+ membershipStateLeaveOrBan, nIDs[eventID],
+ ); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+ return
+}
+
+// GetMembership implements query.RoomserverQueryAPIDB
+func (d *Database) GetMembership(
+ ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
+) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID)
+ if err != nil {
+ return err
+ }
+
+ membershipEventNID, _, err =
+ d.statements.selectMembershipFromRoomAndTarget(
+ ctx, txn, roomNID, requestSenderUserNID,
+ )
+ if err == sql.ErrNoRows {
+ // The user has never been a member of that room
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ stillInRoom = true
+ return nil
+ })
+
+ return
+}
+
+// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
+func (d *Database) GetMembershipEventNIDsForRoom(
+ ctx context.Context, roomNID types.RoomNID, joinOnly bool,
+) (eventNIDs []types.EventNID, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if joinOnly {
+ eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
+ ctx, txn, roomNID, membershipStateJoin,
+ )
+ return nil
+ }
+
+ eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID)
+ return nil
+ })
+ return
+}
+
+// EventsFromIDs implements query.RoomserverQueryAPIEventDB
+func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
+ nidMap, err := d.EventNIDs(ctx, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ var nids []types.EventNID
+ for _, nid := range nidMap {
+ nids = append(nids, nid)
+ }
+
+ return d.Events(ctx, nids)
+}
+
+func (d *Database) GetRoomVersionForRoom(
+ ctx context.Context, roomNID types.RoomNID,
+) (int64, error) {
+ return d.statements.selectRoomVersionForRoomNID(
+ ctx, nil, roomNID,
+ )
+}
+
+type transaction struct {
+ ctx context.Context
+ txn *sql.Tx
+}
+
+// Commit implements types.Transaction
+func (t *transaction) Commit() error {
+ if t.txn == nil {
+ return nil
+ }
+ return t.txn.Commit()
+}
+
+// Rollback implements types.Transaction
+func (t *transaction) Rollback() error {
+ if t.txn == nil {
+ return nil
+ }
+ return t.txn.Rollback()
+}
diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go
new file mode 100644
index 00000000..7740e5f0
--- /dev/null
+++ b/roomserver/storage/sqlite3/transactions_table.go
@@ -0,0 +1,86 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+)
+
+const transactionsSchema = `
+ CREATE TABLE IF NOT EXISTS roomserver_transactions (
+ transaction_id TEXT NOT NULL,
+ session_id INTEGER NOT NULL,
+ user_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ PRIMARY KEY (transaction_id, session_id, user_id)
+ );
+`
+const insertTransactionSQL = `
+ INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)
+ VALUES ($1, $2, $3, $4)
+`
+
+const selectTransactionEventIDSQL = `
+ SELECT event_id FROM roomserver_transactions
+ WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3
+`
+
+type transactionStatements struct {
+ insertTransactionStmt *sql.Stmt
+ selectTransactionEventIDStmt *sql.Stmt
+}
+
+func (s *transactionStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(transactionsSchema)
+ if err != nil {
+ return
+ }
+
+ return statementList{
+ {&s.insertTransactionStmt, insertTransactionSQL},
+ {&s.selectTransactionEventIDStmt, selectTransactionEventIDSQL},
+ }.prepare(db)
+}
+
+func (s *transactionStatements) insertTransaction(
+ ctx context.Context, txn *sql.Tx,
+ transactionID string,
+ sessionID int64,
+ userID string,
+ eventID string,
+) (err error) {
+ stmt := common.TxStmt(txn, s.insertTransactionStmt)
+ _, err = stmt.ExecContext(
+ ctx, transactionID, sessionID, userID, eventID,
+ )
+ return
+}
+
+func (s *transactionStatements) selectTransactionEventID(
+ ctx context.Context, txn *sql.Tx,
+ transactionID string,
+ sessionID int64,
+ userID string,
+) (eventID string, err error) {
+ stmt := common.TxStmt(txn, s.selectTransactionEventIDStmt)
+ err = stmt.QueryRowContext(
+ ctx, transactionID, sessionID, userID,
+ ).Scan(&eventID)
+ return
+}
diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go
index 90841168..551d97cd 100644
--- a/roomserver/storage/storage.go
+++ b/roomserver/storage/storage.go
@@ -19,25 +19,20 @@ import (
"net/url"
"github.com/matrix-org/dendrite/roomserver/api"
+ statedb "github.com/matrix-org/dendrite/roomserver/state/database"
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
+ "github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
+ statedb.RoomStateDatabase
StoreEvent(ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error)
StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
- EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
- EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
- Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
- AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
- StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
- StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
- StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
- SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error)
GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error)
@@ -49,7 +44,6 @@ type Database interface {
GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
GetCreatorIDForAlias(ctx context.Context, alias string) (string, error)
RemoveRoomAlias(ctx context.Context, alias string) error
- StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
MembershipUpdater(ctx context.Context, roomID, targetUserID string) (types.MembershipUpdater, error)
GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error)
@@ -66,6 +60,8 @@ func Open(dataSourceName string) (Database, error) {
switch uri.Scheme {
case "postgres":
return postgres.Open(dataSourceName)
+ case "file":
+ return sqlite3.Open(dataSourceName)
default:
return postgres.Open(dataSourceName)
}
diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go
index 8916565d..be90e0a0 100644
--- a/syncapi/routing/routing.go
+++ b/syncapi/routing/routing.go
@@ -39,7 +39,7 @@ const pathPrefixR0 = "/_matrix/client/r0"
// nolint: gocyclo
func Setup(
apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database,
- deviceDB *devices.Database, federation *gomatrixserverlib.FederationClient,
+ deviceDB devices.Database, federation *gomatrixserverlib.FederationClient,
queryAPI api.RoomserverQueryAPI,
cfg *config.Dendrite,
) {
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go
index aec37185..6a33a8b4 100644
--- a/syncapi/storage/postgres/syncserver.go
+++ b/syncapi/storage/postgres/syncserver.go
@@ -413,13 +413,18 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
numRecentEventsPerRoom int,
wantFullState bool,
res *types.Response,
-) ([]string, error) {
+) (joinedRoomIDs []string, err error) {
txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
if err != nil {
return nil, err
}
var succeeded bool
- defer common.EndTransaction(txn, &succeeded)
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request
@@ -428,7 +433,6 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
// This works out what the 'state' key should be for each room as well as which membership block
// to put the room into.
var deltas []stateDelta
- var joinedRoomIDs []string
if !wantFullState {
deltas, joinedRoomIDs, err = d.getStateDeltas(
ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilter,
@@ -570,7 +574,12 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
return
}
var succeeded bool
- defer common.EndTransaction(txn, &succeeded)
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
// Get the current sync position which we will base the sync response on.
toPos, err = d.syncPositionTx(ctx, txn)
diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go
new file mode 100644
index 00000000..3274e66e
--- /dev/null
+++ b/syncapi/storage/sqlite3/account_data_table.go
@@ -0,0 +1,143 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const accountDataSchema = `
+CREATE TABLE IF NOT EXISTS syncapi_account_data_type (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ UNIQUE (user_id, room_id, type)
+);
+`
+
+const insertAccountDataSQL = "" +
+ "INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" +
+ " ON CONFLICT (user_id, room_id, type) DO UPDATE" +
+ " SET id = EXCLUDED.id"
+
+const selectAccountDataInRangeSQL = "" +
+ "SELECT room_id, type FROM syncapi_account_data_type" +
+ " WHERE user_id = $1 AND id > $2 AND id <= $3" +
+ " AND ( $4 IS NULL OR type IN ($4) )" +
+ " AND ( $5 IS NULL OR NOT(type IN ($5)) )" +
+ " ORDER BY id ASC LIMIT $6"
+
+const selectMaxAccountDataIDSQL = "" +
+ "SELECT MAX(id) FROM syncapi_account_data_type"
+
+type accountDataStatements struct {
+ streamIDStatements *streamIDStatements
+ insertAccountDataStmt *sql.Stmt
+ selectAccountDataInRangeStmt *sql.Stmt
+ selectMaxAccountDataIDStmt *sql.Stmt
+}
+
+func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
+ s.streamIDStatements = streamID
+ _, err = db.Exec(accountDataSchema)
+ if err != nil {
+ return
+ }
+ if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
+ return
+ }
+ if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil {
+ return
+ }
+ if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *accountDataStatements) insertAccountData(
+ ctx context.Context, txn *sql.Tx,
+ userID, roomID, dataType string,
+) (pos types.StreamPosition, err error) {
+ pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return
+ }
+ insertStmt := common.TxStmt(txn, s.insertAccountDataStmt)
+ _, err = insertStmt.ExecContext(ctx, pos, userID, roomID, dataType)
+ return
+}
+
+func (s *accountDataStatements) selectAccountDataInRange(
+ ctx context.Context,
+ userID string,
+ oldPos, newPos types.StreamPosition,
+ accountDataFilterPart *gomatrixserverlib.EventFilter,
+) (data map[string][]string, err error) {
+ data = make(map[string][]string)
+
+ // If both positions are the same, it means that the data was saved after the
+ // latest room event. In that case, we need to decrement the old position as
+ // it would prevent the SQL request from returning anything.
+ if oldPos == newPos {
+ oldPos--
+ }
+
+ rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos,
+ pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.Types)),
+ pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.NotTypes)),
+ accountDataFilterPart.Limit,
+ )
+ if err != nil {
+ return
+ }
+
+ for rows.Next() {
+ var dataType string
+ var roomID string
+
+ if err = rows.Scan(&roomID, &dataType); err != nil {
+ return
+ }
+
+ if len(data[roomID]) > 0 {
+ data[roomID] = append(data[roomID], dataType)
+ } else {
+ data[roomID] = []string{dataType}
+ }
+ }
+
+ return
+}
+
+func (s *accountDataStatements) selectMaxAccountDataID(
+ ctx context.Context, txn *sql.Tx,
+) (id int64, err error) {
+ var nullableID sql.NullInt64
+ stmt := common.TxStmt(txn, s.selectMaxAccountDataIDStmt)
+ err = stmt.QueryRowContext(ctx).Scan(&nullableID)
+ if nullableID.Valid {
+ id = nullableID.Int64
+ }
+ return
+}
diff --git a/syncapi/storage/sqlite3/backward_extremities_table.go b/syncapi/storage/sqlite3/backward_extremities_table.go
new file mode 100644
index 00000000..fcf15da2
--- /dev/null
+++ b/syncapi/storage/sqlite3/backward_extremities_table.go
@@ -0,0 +1,124 @@
+// Copyright 2018 New Vector Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+)
+
+const backwardExtremitiesSchema = `
+-- Stores output room events received from the roomserver.
+CREATE TABLE IF NOT EXISTS syncapi_backward_extremities (
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+
+ PRIMARY KEY(room_id, event_id)
+);
+`
+
+const insertBackwardExtremitySQL = "" +
+ "INSERT INTO syncapi_backward_extremities (room_id, event_id)" +
+ " VALUES ($1, $2)" +
+ " ON CONFLICT (room_id, event_id) DO NOTHING"
+
+const selectBackwardExtremitiesForRoomSQL = "" +
+ "SELECT event_id FROM syncapi_backward_extremities WHERE room_id = $1"
+
+const isBackwardExtremitySQL = "" +
+ "SELECT EXISTS (" +
+ " SELECT TRUE FROM syncapi_backward_extremities" +
+ " WHERE room_id = $1 AND event_id = $2" +
+ ")"
+
+const deleteBackwardExtremitySQL = "" +
+ "DELETE FROM syncapi_backward_extremities" +
+ " WHERE room_id = $1 AND event_id = $2"
+
+type backwardExtremitiesStatements struct {
+ insertBackwardExtremityStmt *sql.Stmt
+ selectBackwardExtremitiesForRoomStmt *sql.Stmt
+ isBackwardExtremityStmt *sql.Stmt
+ deleteBackwardExtremityStmt *sql.Stmt
+}
+
+func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(backwardExtremitiesSchema)
+ if err != nil {
+ return
+ }
+ if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil {
+ return
+ }
+ if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil {
+ return
+ }
+ if s.isBackwardExtremityStmt, err = db.Prepare(isBackwardExtremitySQL); err != nil {
+ return
+ }
+ if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *backwardExtremitiesStatements) insertsBackwardExtremity(
+ ctx context.Context, txn *sql.Tx, roomID, eventID string,
+) (err error) {
+ stmt := common.TxStmt(txn, s.insertBackwardExtremityStmt)
+ _, err = stmt.ExecContext(ctx, roomID, eventID)
+ return
+}
+
+func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (eventIDs []string, err error) {
+ eventIDs = make([]string, 0)
+
+ stmt := common.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt)
+ rows, err := stmt.QueryContext(ctx, roomID)
+ if err != nil {
+ return
+ }
+
+ for rows.Next() {
+ var eID string
+ if err = rows.Scan(&eID); err != nil {
+ return
+ }
+
+ eventIDs = append(eventIDs, eID)
+ }
+
+ return
+}
+
+func (s *backwardExtremitiesStatements) isBackwardExtremity(
+ ctx context.Context, txn *sql.Tx, roomID, eventID string,
+) (isBE bool, err error) {
+ stmt := common.TxStmt(txn, s.isBackwardExtremityStmt)
+ err = stmt.QueryRowContext(ctx, roomID, eventID).Scan(&isBE)
+ return
+}
+
+func (s *backwardExtremitiesStatements) deleteBackwardExtremity(
+ ctx context.Context, txn *sql.Tx, roomID, eventID string,
+) (err error) {
+ stmt := common.TxStmt(txn, s.deleteBackwardExtremityStmt)
+ _, err = stmt.ExecContext(ctx, roomID, eventID)
+ return
+}
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
new file mode 100644
index 00000000..4ce94666
--- /dev/null
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -0,0 +1,276 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const currentRoomStateSchema = `
+-- Stores the current room state for every room.
+CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ sender TEXT NOT NULL,
+ contains_url BOOL NOT NULL DEFAULT false,
+ state_key TEXT NOT NULL,
+ event_json TEXT NOT NULL,
+ membership TEXT,
+ added_at BIGINT,
+ UNIQUE (room_id, type, state_key)
+);
+-- for event deletion
+CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
+-- for querying membership states of users
+-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
+`
+
+const upsertRoomStateSQL = "" +
+ "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, event_json, membership, added_at)" +
+ " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
+ " ON CONFLICT (event_id, room_id, type, sender, contains_url)" +
+ " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, event_json = $7, membership = $8, added_at = $9"
+
+const deleteRoomStateByEventIDSQL = "" +
+ "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
+
+const selectRoomIDsWithMembershipSQL = "" +
+ "SELECT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
+
+const selectCurrentStateSQL = "" +
+ "SELECT event_json FROM syncapi_current_room_state WHERE room_id = $1" +
+ " AND ( $2 IS NULL OR sender IN ($2) )" +
+ " AND ( $3 IS NULL OR NOT(sender IN ($3)) )" +
+ " AND ( $4 IS NULL OR type IN ($4) )" +
+ " AND ( $5 IS NULL OR NOT(type IN ($5)) )" +
+ " AND ( $6 IS NULL OR contains_url = $6 )" +
+ " LIMIT $7"
+
+const selectJoinedUsersSQL = "" +
+ "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
+
+const selectStateEventSQL = "" +
+ "SELECT event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
+
+const selectEventsWithEventIDsSQL = "" +
+ // TODO: The session_id and transaction_id blanks are here because otherwise
+ // the rowsToStreamEvents expects there to be exactly five columns. We need to
+ // figure out if these really need to be in the DB, and if so, we need a
+ // better permanent fix for this. - neilalexander, 2 Jan 2020
+ "SELECT added_at, event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
+ " FROM syncapi_current_room_state WHERE event_id IN ($1)"
+
+type currentRoomStateStatements struct {
+ streamIDStatements *streamIDStatements
+ upsertRoomStateStmt *sql.Stmt
+ deleteRoomStateByEventIDStmt *sql.Stmt
+ selectRoomIDsWithMembershipStmt *sql.Stmt
+ selectCurrentStateStmt *sql.Stmt
+ selectJoinedUsersStmt *sql.Stmt
+ selectEventsWithEventIDsStmt *sql.Stmt
+ selectStateEventStmt *sql.Stmt
+}
+
+func (s *currentRoomStateStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
+ s.streamIDStatements = streamID
+ _, err = db.Exec(currentRoomStateSchema)
+ if err != nil {
+ return
+ }
+ if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
+ return
+ }
+ if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
+ return
+ }
+ if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
+ return
+ }
+ if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
+ return
+ }
+ if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
+ return
+ }
+ if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
+ return
+ }
+ if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
+ return
+ }
+ return
+}
+
+// JoinedMemberLists returns a map of room ID to a list of joined user IDs.
+func (s *currentRoomStateStatements) selectJoinedUsers(
+ ctx context.Context,
+) (map[string][]string, error) {
+ rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ result := make(map[string][]string)
+ for rows.Next() {
+ var roomID string
+ var userID string
+ if err := rows.Scan(&roomID, &userID); err != nil {
+ return nil, err
+ }
+ users := result[roomID]
+ users = append(users, userID)
+ result[roomID] = users
+ }
+ return result, nil
+}
+
+// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
+func (s *currentRoomStateStatements) selectRoomIDsWithMembership(
+ ctx context.Context,
+ txn *sql.Tx,
+ userID string,
+ membership string, // nolint: unparam
+) ([]string, error) {
+ stmt := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt)
+ rows, err := stmt.QueryContext(ctx, userID, membership)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ var result []string
+ for rows.Next() {
+ var roomID string
+ if err := rows.Scan(&roomID); err != nil {
+ return nil, err
+ }
+ result = append(result, roomID)
+ }
+ return result, nil
+}
+
+// CurrentState returns all the current state events for the given room.
+func (s *currentRoomStateStatements) selectCurrentState(
+ ctx context.Context, txn *sql.Tx, roomID string,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) ([]gomatrixserverlib.Event, error) {
+ stmt := common.TxStmt(txn, s.selectCurrentStateStmt)
+ rows, err := stmt.QueryContext(ctx, roomID,
+ pq.StringArray(stateFilterPart.Senders),
+ pq.StringArray(stateFilterPart.NotSenders),
+ pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
+ pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
+ stateFilterPart.ContainsURL,
+ stateFilterPart.Limit,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ return rowsToEvents(rows)
+}
+
+func (s *currentRoomStateStatements) deleteRoomStateByEventID(
+ ctx context.Context, txn *sql.Tx, eventID string,
+) error {
+ stmt := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
+ _, err := stmt.ExecContext(ctx, eventID)
+ return err
+}
+
+func (s *currentRoomStateStatements) upsertRoomState(
+ ctx context.Context, txn *sql.Tx,
+ event gomatrixserverlib.Event, membership *string, addedAt types.StreamPosition,
+) error {
+ // Parse content as JSON and search for an "url" key
+ containsURL := false
+ var content map[string]interface{}
+ if json.Unmarshal(event.Content(), &content) != nil {
+ // Set containsURL to true if url is present
+ _, containsURL = content["url"]
+ }
+
+ // upsert state event
+ stmt := common.TxStmt(txn, s.upsertRoomStateStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ event.RoomID(),
+ event.EventID(),
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ *event.StateKey(),
+ event.JSON(),
+ membership,
+ addedAt,
+ )
+ return err
+}
+
+func (s *currentRoomStateStatements) selectEventsWithEventIDs(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) ([]types.StreamEvent, error) {
+ stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt)
+ rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ return rowsToStreamEvents(rows)
+}
+
+func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
+ result := []gomatrixserverlib.Event{}
+ for rows.Next() {
+ var eventBytes []byte
+ if err := rows.Scan(&eventBytes); err != nil {
+ return nil, err
+ }
+ // TODO: Handle redacted events
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, ev)
+ }
+ return result, nil
+}
+
+func (s *currentRoomStateStatements) selectStateEvent(
+ ctx context.Context, roomID, evType, stateKey string,
+) (*gomatrixserverlib.Event, error) {
+ stmt := s.selectStateEventStmt
+ var res []byte
+ err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON(res, false)
+ return &ev, err
+}
diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go
new file mode 100644
index 00000000..c4a2f4bf
--- /dev/null
+++ b/syncapi/storage/sqlite3/filtering.go
@@ -0,0 +1,36 @@
+// Copyright 2017 Thibaut CHARLES
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "strings"
+)
+
+// filterConvertWildcardToSQL converts wildcards as defined in
+// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
+// to SQL wildcards that can be used with LIKE()
+func filterConvertTypeWildcardToSQL(values []string) []string {
+ if values == nil {
+ // Return nil instead of []string{} so IS NULL can work correctly when
+ // the return value is passed into SQL queries
+ return nil
+ }
+
+ ret := make([]string, len(values))
+ for i := range values {
+ ret[i] = strings.Replace(values[i], "*", "%", -1)
+ }
+ return ret
+}
diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go
new file mode 100644
index 00000000..74dba245
--- /dev/null
+++ b/syncapi/storage/sqlite3/invites_table.go
@@ -0,0 +1,157 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const inviteEventsSchema = `
+CREATE TABLE IF NOT EXISTS syncapi_invite_events (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ event_json TEXT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id);
+CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id);
+`
+
+const insertInviteEventSQL = "" +
+ "INSERT INTO syncapi_invite_events" +
+ " (room_id, event_id, target_user_id, event_json)" +
+ " VALUES ($1, $2, $3, $4)"
+
+const selectLastInsertedInviteEventSQL = "" +
+ "SELECT id FROM syncapi_invite_events WHERE rowid = last_insert_rowid()"
+
+const deleteInviteEventSQL = "" +
+ "DELETE FROM syncapi_invite_events WHERE event_id = $1"
+
+const selectInviteEventsInRangeSQL = "" +
+ "SELECT room_id, event_json FROM syncapi_invite_events" +
+ " WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
+ " ORDER BY id DESC"
+
+const selectMaxInviteIDSQL = "" +
+ "SELECT MAX(id) FROM syncapi_invite_events"
+
+type inviteEventsStatements struct {
+ streamIDStatements *streamIDStatements
+ insertInviteEventStmt *sql.Stmt
+ selectLastInsertedInviteEventStmt *sql.Stmt
+ selectInviteEventsInRangeStmt *sql.Stmt
+ deleteInviteEventStmt *sql.Stmt
+ selectMaxInviteIDStmt *sql.Stmt
+}
+
+func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
+ s.streamIDStatements = streamID
+ _, err = db.Exec(inviteEventsSchema)
+ if err != nil {
+ return
+ }
+ if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil {
+ return
+ }
+ if s.selectLastInsertedInviteEventStmt, err = db.Prepare(selectLastInsertedInviteEventSQL); err != nil {
+ return
+ }
+ if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil {
+ return
+ }
+ if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil {
+ return
+ }
+ if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *inviteEventsStatements) insertInviteEvent(
+ ctx context.Context, inviteEvent gomatrixserverlib.Event,
+) (streamPos types.StreamPosition, err error) {
+ _, err = s.insertInviteEventStmt.ExecContext(
+ ctx,
+ inviteEvent.RoomID(),
+ inviteEvent.EventID(),
+ *inviteEvent.StateKey(),
+ inviteEvent.JSON(),
+ )
+ if err != nil {
+ return
+ }
+ err = s.selectLastInsertedInviteEventStmt.QueryRowContext(ctx).Scan(&streamPos)
+ return
+}
+
+func (s *inviteEventsStatements) deleteInviteEvent(
+ ctx context.Context, inviteEventID string,
+) error {
+ _, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID)
+ return err
+}
+
+// selectInviteEventsInRange returns a map of room ID to invite event for the
+// active invites for the target user ID in the supplied range.
+func (s *inviteEventsStatements) selectInviteEventsInRange(
+ ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition,
+) (map[string]gomatrixserverlib.Event, error) {
+ stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt)
+ rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ result := map[string]gomatrixserverlib.Event{}
+ for rows.Next() {
+ var (
+ roomID string
+ eventJSON []byte
+ )
+ if err = rows.Scan(&roomID, &eventJSON); err != nil {
+ return nil, err
+ }
+
+ event, err := gomatrixserverlib.NewEventFromTrustedJSON(eventJSON, false)
+ if err != nil {
+ return nil, err
+ }
+
+ result[roomID] = event
+ }
+ return result, nil
+}
+
+func (s *inviteEventsStatements) selectMaxInviteID(
+ ctx context.Context, txn *sql.Tx,
+) (id int64, err error) {
+ var nullableID sql.NullInt64
+ stmt := common.TxStmt(txn, s.selectMaxInviteIDStmt)
+ err = stmt.QueryRowContext(ctx).Scan(&nullableID)
+ if nullableID.Valid {
+ id = nullableID.Int64
+ }
+ return
+}
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
new file mode 100644
index 00000000..8c01f2ce
--- /dev/null
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -0,0 +1,411 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "sort"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/types"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+const outputRoomEventsSchema = `
+-- Stores output room events received from the roomserver.
+CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ event_id TEXT NOT NULL UNIQUE,
+ room_id TEXT NOT NULL,
+ event_json TEXT NOT NULL,
+ type TEXT NOT NULL,
+ sender TEXT NOT NULL,
+ contains_url BOOL NOT NULL,
+ add_state_ids TEXT[],
+ remove_state_ids TEXT[],
+ session_id BIGINT,
+ transaction_id TEXT,
+ exclude_from_sync BOOL DEFAULT FALSE
+);
+`
+
+const insertEventSQL = "" +
+ "INSERT INTO syncapi_output_room_events (" +
+ "id, room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
+ ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
+ "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = $11"
+
+const selectLastInsertedEventSQL = "" +
+ "SELECT id FROM syncapi_output_room_events WHERE rowid = last_insert_rowid()"
+
+const selectEventsSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
+
+const selectRecentEventsSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ " WHERE room_id = $1 AND id > $2 AND id <= $3" +
+ " ORDER BY id DESC LIMIT $4"
+
+const selectRecentEventsForSyncSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
+ " ORDER BY id DESC LIMIT $4"
+
+const selectEarlyEventsSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ " WHERE room_id = $1 AND id > $2 AND id <= $3" +
+ " ORDER BY id ASC LIMIT $4"
+
+const selectMaxEventIDSQL = "" +
+ "SELECT MAX(id) FROM syncapi_output_room_events"
+
+// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
+/*
+ $1 = oldPos,
+ $2 = newPos,
+ $3 = pq.StringArray(stateFilterPart.Senders),
+ $4 = pq.StringArray(stateFilterPart.NotSenders),
+ $5 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
+ $6 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
+ $7 = stateFilterPart.ContainsURL,
+ $8 = stateFilterPart.Limit,
+*/
+const selectStateInRangeSQL = "" +
+ "SELECT id, event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
+ " FROM syncapi_output_room_events" +
+ " WHERE (id > $1 AND id <= $2)" + // old/new pos
+ " AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
+ /* " AND ( $3 IS NULL OR sender IN ($3) )" + // sender
+ " AND ( $4 IS NULL OR NOT(sender IN ($4)) )" + // not sender
+ " AND ( $5 IS NULL OR type IN ($5) )" + // type
+ " AND ( $6 IS NULL OR NOT(type IN ($6)) )" + // not type
+ " AND ( $7 IS NULL OR contains_url = $7)" + // contains URL? */
+ " ORDER BY id ASC" +
+ " LIMIT $8" // limit
+
+type outputRoomEventsStatements struct {
+ streamIDStatements *streamIDStatements
+ insertEventStmt *sql.Stmt
+ selectLastInsertedEventStmt *sql.Stmt
+ selectEventsStmt *sql.Stmt
+ selectMaxEventIDStmt *sql.Stmt
+ selectRecentEventsStmt *sql.Stmt
+ selectRecentEventsForSyncStmt *sql.Stmt
+ selectEarlyEventsStmt *sql.Stmt
+ selectStateInRangeStmt *sql.Stmt
+}
+
+func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
+ s.streamIDStatements = streamID
+ _, err = db.Exec(outputRoomEventsSchema)
+ if err != nil {
+ return
+ }
+ if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
+ return
+ }
+ if s.selectLastInsertedEventStmt, err = db.Prepare(selectLastInsertedEventSQL); err != nil {
+ return
+ }
+ if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil {
+ return
+ }
+ if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil {
+ return
+ }
+ if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
+ return
+ }
+ if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil {
+ return
+ }
+ if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil {
+ return
+ }
+ if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
+ return
+ }
+ return
+}
+
+// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
+// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
+// two positions, only the most recent state is returned.
+func (s *outputRoomEventsStatements) selectStateInRange(
+ ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
+ stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
+
+ rows, err := stmt.QueryContext(
+ ctx, oldPos, newPos,
+ /*pq.StringArray(stateFilterPart.Senders),
+ pq.StringArray(stateFilterPart.NotSenders),
+ pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
+ pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
+ stateFilterPart.ContainsURL,*/
+ stateFilterPart.Limit,
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+ // Fetch all the state change events for all rooms between the two positions then loop each event and:
+ // - Keep a cache of the event by ID (99% of state change events are for the event itself)
+ // - For each room ID, build up an array of event IDs which represents cumulative adds/removes
+ // For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID
+ // if they aren't in the event ID cache. We don't handle state deletion yet.
+ eventIDToEvent := make(map[string]types.StreamEvent)
+
+ // RoomID => A set (map[string]bool) of state event IDs which are between the two positions
+ stateNeeded := make(map[string]map[string]bool)
+
+ for rows.Next() {
+ var (
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ addIDs pq.StringArray
+ delIDs pq.StringArray
+ )
+ if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs); err != nil {
+ return nil, nil, err
+ }
+ // Sanity check for deleted state and whine if we see it. We don't need to do anything
+ // since it'll just mark the event as not being needed.
+ if len(addIDs) < len(delIDs) {
+ log.WithFields(log.Fields{
+ "since": oldPos,
+ "current": newPos,
+ "adds": addIDs,
+ "dels": delIDs,
+ }).Warn("StateBetween: ignoring deleted state")
+ }
+
+ // TODO: Handle redacted events
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false)
+ if err != nil {
+ return nil, nil, err
+ }
+ needSet := stateNeeded[ev.RoomID()]
+ if needSet == nil { // make set if required
+ needSet = make(map[string]bool)
+ }
+ for _, id := range delIDs {
+ needSet[id] = false
+ }
+ for _, id := range addIDs {
+ needSet[id] = true
+ }
+ stateNeeded[ev.RoomID()] = needSet
+
+ eventIDToEvent[ev.EventID()] = types.StreamEvent{
+ Event: ev,
+ StreamPosition: streamPos,
+ ExcludeFromSync: excludeFromSync,
+ }
+ }
+
+ return stateNeeded, eventIDToEvent, nil
+}
+
+// MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied,
+// then this function should only ever be used at startup, as it will race with inserting events if it is
+// done afterwards. If there are no inserted events, 0 is returned.
+func (s *outputRoomEventsStatements) selectMaxEventID(
+ ctx context.Context, txn *sql.Tx,
+) (id int64, err error) {
+ var nullableID sql.NullInt64
+ stmt := common.TxStmt(txn, s.selectMaxEventIDStmt)
+ err = stmt.QueryRowContext(ctx).Scan(&nullableID)
+ if nullableID.Valid {
+ id = nullableID.Int64
+ }
+ return
+}
+
+// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position
+// of the inserted event.
+func (s *outputRoomEventsStatements) insertEvent(
+ ctx context.Context, txn *sql.Tx,
+ event *gomatrixserverlib.Event, addState, removeState []string,
+ transactionID *api.TransactionID, excludeFromSync bool,
+) (streamPos types.StreamPosition, err error) {
+ var txnID *string
+ var sessionID *int64
+ if transactionID != nil {
+ sessionID = &transactionID.SessionID
+ txnID = &transactionID.TransactionID
+ }
+
+ // Parse content as JSON and search for an "url" key
+ containsURL := false
+ var content map[string]interface{}
+ if json.Unmarshal(event.Content(), &content) != nil {
+ // Set containsURL to true if url is present
+ _, containsURL = content["url"]
+ }
+
+ streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return
+ }
+
+ insertStmt := common.TxStmt(txn, s.insertEventStmt)
+ selectStmt := common.TxStmt(txn, s.selectLastInsertedEventStmt)
+ _, err = insertStmt.ExecContext(
+ ctx,
+ streamPos,
+ event.RoomID(),
+ event.EventID(),
+ event.JSON(),
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ pq.StringArray(addState),
+ pq.StringArray(removeState),
+ sessionID,
+ txnID,
+ excludeFromSync,
+ )
+ if err != nil {
+ return
+ }
+ err = selectStmt.QueryRowContext(ctx).Scan(&streamPos)
+ return
+}
+
+// selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'.
+// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude
+// from sync.
+func (s *outputRoomEventsStatements) selectRecentEvents(
+ ctx context.Context, txn *sql.Tx,
+ roomID string, fromPos, toPos types.StreamPosition, limit int,
+ chronologicalOrder bool, onlySyncEvents bool,
+) ([]types.StreamEvent, error) {
+ var stmt *sql.Stmt
+ if onlySyncEvents {
+ stmt = common.TxStmt(txn, s.selectRecentEventsForSyncStmt)
+ } else {
+ stmt = common.TxStmt(txn, s.selectRecentEventsStmt)
+ }
+
+ rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ events, err := rowsToStreamEvents(rows)
+ if err != nil {
+ return nil, err
+ }
+ if chronologicalOrder {
+ // The events need to be returned from oldest to latest, which isn't
+ // necessary the way the SQL query returns them, so a sort is necessary to
+ // ensure the events are in the right order in the slice.
+ sort.SliceStable(events, func(i int, j int) bool {
+ return events[i].StreamPosition < events[j].StreamPosition
+ })
+ }
+ return events, nil
+}
+
+// selectEarlyEvents returns the earliest events in the given room, starting
+// from a given position, up to a maximum of 'limit'.
+func (s *outputRoomEventsStatements) selectEarlyEvents(
+ ctx context.Context, txn *sql.Tx,
+ roomID string, fromPos, toPos types.StreamPosition, limit int,
+) ([]types.StreamEvent, error) {
+ stmt := common.TxStmt(txn, s.selectEarlyEventsStmt)
+ rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ events, err := rowsToStreamEvents(rows)
+ if err != nil {
+ return nil, err
+ }
+ // The events need to be returned from oldest to latest, which isn't
+ // necessarily the way the SQL query returns them, so a sort is necessary to
+ // ensure the events are in the right order in the slice.
+ sort.SliceStable(events, func(i int, j int) bool {
+ return events[i].StreamPosition < events[j].StreamPosition
+ })
+ return events, nil
+}
+
+// selectEvents returns the events for the given event IDs. If an event is
+// missing from the database, it will be omitted.
+func (s *outputRoomEventsStatements) selectEvents(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) ([]types.StreamEvent, error) {
+ var returnEvents []types.StreamEvent
+ stmt := common.TxStmt(txn, s.selectEventsStmt)
+ for _, eventID := range eventIDs {
+ rows, err := stmt.QueryContext(ctx, eventID)
+ if err != nil {
+ return nil, err
+ }
+ if streamEvents, err := rowsToStreamEvents(rows); err == nil {
+ returnEvents = append(returnEvents, streamEvents...)
+ }
+ rows.Close() // nolint: errcheck
+ }
+ return returnEvents, nil
+}
+
+func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
+ var result []types.StreamEvent
+ for rows.Next() {
+ var (
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ sessionID *int64
+ txnID *string
+ transactionID *api.TransactionID
+ )
+ if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
+ return nil, err
+ }
+ // TODO: Handle redacted events
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false)
+ if err != nil {
+ return nil, err
+ }
+
+ if sessionID != nil && txnID != nil {
+ transactionID = &api.TransactionID{
+ SessionID: *sessionID,
+ TransactionID: *txnID,
+ }
+ }
+
+ result = append(result, types.StreamEvent{
+ Event: ev,
+ StreamPosition: streamPos,
+ TransactionID: transactionID,
+ ExcludeFromSync: excludeFromSync,
+ })
+ }
+ return result, nil
+}
diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go
new file mode 100644
index 00000000..f7075bd6
--- /dev/null
+++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go
@@ -0,0 +1,192 @@
+// Copyright 2018 New Vector Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const outputRoomEventsTopologySchema = `
+-- Stores output room events received from the roomserver.
+CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology (
+ event_id TEXT PRIMARY KEY,
+ topological_position BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+
+ UNIQUE(topological_position, room_id)
+);
+-- The topological order will be used in events selection and ordering
+-- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id);
+`
+
+const insertEventInTopologySQL = "" +
+ "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" +
+ " VALUES ($1, $2, $3)" +
+ " ON CONFLICT (topological_position, room_id) DO UPDATE SET event_id = $1"
+
+const selectEventIDsInRangeASCSQL = "" +
+ "SELECT event_id FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" +
+ " ORDER BY topological_position ASC LIMIT $4"
+
+const selectEventIDsInRangeDESCSQL = "" +
+ "SELECT event_id FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" +
+ " ORDER BY topological_position DESC LIMIT $4"
+
+const selectPositionInTopologySQL = "" +
+ "SELECT topological_position FROM syncapi_output_room_events_topology" +
+ " WHERE event_id = $1"
+
+const selectMaxPositionInTopologySQL = "" +
+ "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1"
+
+const selectEventIDsFromPositionSQL = "" +
+ "SELECT event_id FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1 AND topological_position = $2"
+
+type outputRoomEventsTopologyStatements struct {
+ insertEventInTopologyStmt *sql.Stmt
+ selectEventIDsInRangeASCStmt *sql.Stmt
+ selectEventIDsInRangeDESCStmt *sql.Stmt
+ selectPositionInTopologyStmt *sql.Stmt
+ selectMaxPositionInTopologyStmt *sql.Stmt
+ selectEventIDsFromPositionStmt *sql.Stmt
+}
+
+func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(outputRoomEventsTopologySchema)
+ if err != nil {
+ return
+ }
+ if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil {
+ return
+ }
+ if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil {
+ return
+ }
+ if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil {
+ return
+ }
+ if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
+ return
+ }
+ if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
+ return
+ }
+ if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil {
+ return
+ }
+ return
+}
+
+// insertEventInTopology inserts the given event in the room's topology, based
+// on the event's depth.
+func (s *outputRoomEventsTopologyStatements) insertEventInTopology(
+ ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.Event,
+) (err error) {
+ stmt := common.TxStmt(txn, s.insertEventInTopologyStmt)
+ _, err = stmt.ExecContext(
+ ctx, event.EventID(), event.Depth(), event.RoomID(),
+ )
+ return
+}
+
+// selectEventIDsInRange selects the IDs of events which positions are within a
+// given range in a given room's topological order.
+// Returns an empty slice if no events match the given range.
+func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
+ ctx context.Context, txn *sql.Tx, roomID string,
+ fromPos, toPos types.StreamPosition,
+ limit int, chronologicalOrder bool,
+) (eventIDs []string, err error) {
+ // Decide on the selection's order according to whether chronological order
+ // is requested or not.
+ var stmt *sql.Stmt
+ if chronologicalOrder {
+ stmt = common.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
+ } else {
+ stmt = common.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
+ }
+
+ // Query the event IDs.
+ rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
+ if err == sql.ErrNoRows {
+ // If no event matched the request, return an empty slice.
+ return []string{}, nil
+ } else if err != nil {
+ return
+ }
+
+ // Return the IDs.
+ var eventID string
+ for rows.Next() {
+ if err = rows.Scan(&eventID); err != nil {
+ return
+ }
+ eventIDs = append(eventIDs, eventID)
+ }
+
+ return
+}
+
+// selectPositionInTopology returns the position of a given event in the
+// topology of the room it belongs to.
+func (s *outputRoomEventsTopologyStatements) selectPositionInTopology(
+ ctx context.Context, txn *sql.Tx, eventID string,
+) (pos types.StreamPosition, err error) {
+ stmt := common.TxStmt(txn, s.selectPositionInTopologyStmt)
+ err = stmt.QueryRowContext(ctx, eventID).Scan(&pos)
+ return
+}
+
+func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (pos types.StreamPosition, err error) {
+ stmt := common.TxStmt(txn, s.selectMaxPositionInTopologyStmt)
+ err = stmt.QueryRowContext(ctx, roomID).Scan(&pos)
+ return
+}
+
+// selectEventIDsFromPosition returns the IDs of all events that have a given
+// position in the topology of a given room.
+func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition(
+ ctx context.Context, txn *sql.Tx, roomID string, pos types.StreamPosition,
+) (eventIDs []string, err error) {
+ // Query the event IDs.
+ stmt := common.TxStmt(txn, s.selectEventIDsFromPositionStmt)
+ rows, err := stmt.QueryContext(ctx, roomID, pos)
+ if err == sql.ErrNoRows {
+ // If no event matched the request, return an empty slice.
+ return []string{}, nil
+ } else if err != nil {
+ return
+ }
+ // Return the IDs.
+ var eventID string
+ for rows.Next() {
+ if err = rows.Scan(&eventID); err != nil {
+ return
+ }
+ eventIDs = append(eventIDs, eventID)
+ }
+ return
+}
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
new file mode 100644
index 00000000..260f7a95
--- /dev/null
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -0,0 +1,58 @@
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+)
+
+const streamIDTableSchema = `
+-- Global stream ID counter, used by other tables.
+CREATE TABLE IF NOT EXISTS syncapi_stream_id (
+ stream_name TEXT NOT NULL PRIMARY KEY,
+ stream_id INT DEFAULT 0,
+
+ UNIQUE(stream_name)
+);
+INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
+ ON CONFLICT DO NOTHING;
+`
+
+const increaseStreamIDStmt = "" +
+ "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1"
+
+const selectStreamIDStmt = "" +
+ "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
+
+type streamIDStatements struct {
+ increaseStreamIDStmt *sql.Stmt
+ selectStreamIDStmt *sql.Stmt
+}
+
+func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(streamIDTableSchema)
+ if err != nil {
+ return
+ }
+ if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil {
+ return
+ }
+ if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil {
+ return
+ }
+ return
+}
+
+func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+ increaseStmt := common.TxStmt(txn, s.increaseStreamIDStmt)
+ selectStmt := common.TxStmt(txn, s.selectStreamIDStmt)
+ if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
+ return
+ }
+ if err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
+ return
+ }
+ return
+}
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
new file mode 100644
index 00000000..8cfc1884
--- /dev/null
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -0,0 +1,1197 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/url"
+ "time"
+
+ "github.com/sirupsen/logrus"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/roomserver/api"
+
+ // Import the postgres database driver.
+ _ "github.com/lib/pq"
+ _ "github.com/mattn/go-sqlite3"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/dendrite/typingserver/cache"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type stateDelta struct {
+ roomID string
+ stateEvents []gomatrixserverlib.Event
+ membership string
+ // The PDU stream position of the latest membership event for this user, if applicable.
+ // Can be 0 if there is no membership event in this delta.
+ membershipPos types.StreamPosition
+}
+
+// SyncServerDatasource represents a sync server datasource which manages
+// both the database for PDUs and caches for EDUs.
+type SyncServerDatasource struct {
+ db *sql.DB
+ common.PartitionOffsetStatements
+ streamID streamIDStatements
+ accountData accountDataStatements
+ events outputRoomEventsStatements
+ roomstate currentRoomStateStatements
+ invites inviteEventsStatements
+ typingCache *cache.TypingCache
+ topology outputRoomEventsTopologyStatements
+ backwardExtremities backwardExtremitiesStatements
+}
+
+// NewSyncServerDatasource creates a new sync server database
+// nolint: gocyclo
+func NewSyncServerDatasource(dataSourceName string) (*SyncServerDatasource, error) {
+ var d SyncServerDatasource
+ uri, err := url.Parse(dataSourceName)
+ if err != nil {
+ return nil, err
+ }
+ var cs string
+ if uri.Opaque != "" { // file:filename.db
+ cs = uri.Opaque
+ } else if uri.Path != "" { // file:///path/to/filename.db
+ cs = uri.Path
+ } else {
+ return nil, errors.New("no filename or path in connect string")
+ }
+ if d.db, err = sql.Open("sqlite3", cs); err != nil {
+ return nil, err
+ }
+ if err = d.prepare(); err != nil {
+ return nil, err
+ }
+ d.typingCache = cache.NewTypingCache()
+ return &d, nil
+}
+
+func (d *SyncServerDatasource) prepare() (err error) {
+ if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
+ return err
+ }
+ if err = d.streamID.prepare(d.db); err != nil {
+ return err
+ }
+ if err = d.accountData.prepare(d.db, &d.streamID); err != nil {
+ return err
+ }
+ if err = d.events.prepare(d.db, &d.streamID); err != nil {
+ return err
+ }
+ if err := d.roomstate.prepare(d.db, &d.streamID); err != nil {
+ return err
+ }
+ if err := d.invites.prepare(d.db, &d.streamID); err != nil {
+ return err
+ }
+ if err := d.topology.prepare(d.db); err != nil {
+ return err
+ }
+ if err := d.backwardExtremities.prepare(d.db); err != nil {
+ return err
+ }
+ return nil
+}
+
+// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
+func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
+ return d.roomstate.selectJoinedUsers(ctx)
+}
+
+// Events lookups a list of event by their event ID.
+// Returns a list of events matching the requested IDs found in the database.
+// If an event is not found in the database then it will be omitted from the list.
+// Returns an error if there was a problem talking with the database.
+// Does not include any transaction IDs in the returned events.
+func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
+ streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // We don't include a device here as we only include transaction IDs in
+ // incremental syncs.
+ return d.StreamEventsToEvents(nil, streamEvents), nil
+}
+
+func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.Event) error {
+ // If the event is already known as a backward extremity, don't consider
+ // it as such anymore now that we have it.
+ isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID())
+ if err != nil {
+ return err
+ }
+ if isBackwardExtremity {
+ if err = d.backwardExtremities.deleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
+ return err
+ }
+ }
+
+ // Check if we have all of the event's previous events. If an event is
+ // missing, add it to the room's backward extremities.
+ prevEvents, err := d.events.selectEvents(ctx, txn, ev.PrevEventIDs())
+ if err != nil {
+ return err
+ }
+ var found bool
+ for _, eID := range ev.PrevEventIDs() {
+ found = false
+ for _, prevEv := range prevEvents {
+ if eID == prevEv.EventID() {
+ found = true
+ }
+ }
+
+ // If the event is missing, consider it a backward extremity.
+ if !found {
+ if err = d.backwardExtremities.insertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
+// when generating the sync stream position for this event. Returns the sync stream position for the inserted event.
+// Returns an error if there was a problem inserting this event.
+func (d *SyncServerDatasource) WriteEvent(
+ ctx context.Context,
+ ev *gomatrixserverlib.Event,
+ addStateEvents []gomatrixserverlib.Event,
+ addStateEventIDs, removeStateEventIDs []string,
+ transactionID *api.TransactionID, excludeFromSync bool,
+) (pduPosition types.StreamPosition, returnErr error) {
+ returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ pos, err := d.events.insertEvent(
+ ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
+ )
+ if err != nil {
+ fmt.Println("d.events.insertEvent:", err)
+ return err
+ }
+ pduPosition = pos
+
+ if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil {
+ fmt.Println("d.topology.insertEventInTopology:", err)
+ return err
+ }
+
+ if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
+ fmt.Println("d.handleBackwardExtremities:", err)
+ return err
+ }
+
+ if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
+ // Nothing to do, the event may have just been a message event.
+ fmt.Println("nothing to do")
+ return nil
+ }
+
+ return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition)
+ })
+
+ return pduPosition, returnErr
+}
+
+func (d *SyncServerDatasource) updateRoomState(
+ ctx context.Context, txn *sql.Tx,
+ removedEventIDs []string,
+ addedEvents []gomatrixserverlib.Event,
+ pduPosition types.StreamPosition,
+) error {
+ // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
+ for _, eventID := range removedEventIDs {
+ if err := d.roomstate.deleteRoomStateByEventID(ctx, txn, eventID); err != nil {
+ return err
+ }
+ }
+
+ for _, event := range addedEvents {
+ if event.StateKey() == nil {
+ // ignore non state events
+ continue
+ }
+ var membership *string
+ if event.Type() == "m.room.member" {
+ value, err := event.Membership()
+ if err != nil {
+ return err
+ }
+ membership = &value
+ }
+ if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
+// If no event could be found, returns nil
+// If there was an issue during the retrieval, returns an error
+func (d *SyncServerDatasource) GetStateEvent(
+ ctx context.Context, roomID, evType, stateKey string,
+) (*gomatrixserverlib.Event, error) {
+ return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey)
+}
+
+// GetStateEventsForRoom fetches the state events for a given room.
+// Returns an empty slice if no state events could be found for this room.
+// Returns an error if there was an issue with the retrieval.
+func (d *SyncServerDatasource) GetStateEventsForRoom(
+ ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter,
+) (stateEvents []gomatrixserverlib.Event, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
+ return err
+ })
+ return
+}
+
+// GetEventsInRange retrieves all of the events on a given ordering using the
+// given extremities and limit.
+func (d *SyncServerDatasource) GetEventsInRange(
+ ctx context.Context,
+ from, to *types.PaginationToken,
+ roomID string, limit int,
+ backwardOrdering bool,
+) (events []types.StreamEvent, err error) {
+ // If the pagination token's type is types.PaginationTokenTypeTopology, the
+ // events must be retrieved from the rooms' topology table rather than the
+ // table contaning the syncapi server's whole stream of events.
+ if from.Type == types.PaginationTokenTypeTopology {
+ // Determine the backward and forward limit, i.e. the upper and lower
+ // limits to the selection in the room's topology, from the direction.
+ var backwardLimit, forwardLimit types.StreamPosition
+ if backwardOrdering {
+ // Backward ordering is antichronological (latest event to oldest
+ // one).
+ backwardLimit = to.PDUPosition
+ forwardLimit = from.PDUPosition
+ } else {
+ // Forward ordering is chronological (oldest event to latest one).
+ backwardLimit = from.PDUPosition
+ forwardLimit = to.PDUPosition
+ }
+
+ // Select the event IDs from the defined range.
+ var eIDs []string
+ eIDs, err = d.topology.selectEventIDsInRange(
+ ctx, nil, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering,
+ )
+ if err != nil {
+ return
+ }
+
+ // Retrieve the events' contents using their IDs.
+ events, err = d.events.selectEvents(ctx, nil, eIDs)
+ return
+ }
+
+ // If the pagination token's type is types.PaginationTokenTypeStream, the
+ // events must be retrieved from the table contaning the syncapi server's
+ // whole stream of events.
+
+ if backwardOrdering {
+ // When using backward ordering, we want the most recent events first.
+ if events, err = d.events.selectRecentEvents(
+ ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false,
+ ); err != nil {
+ return
+ }
+ } else {
+ // When using forward ordering, we want the least recent events first.
+ if events, err = d.events.selectEarlyEvents(
+ ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit,
+ ); err != nil {
+ return
+ }
+ }
+
+ return
+}
+
+// SyncPosition returns the latest positions for syncing.
+func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) {
+ return d.syncPositionTx(ctx, nil)
+}
+
+// BackwardExtremitiesForRoom returns the event IDs of all of the backward
+// extremities we know of for a given room.
+func (d *SyncServerDatasource) BackwardExtremitiesForRoom(
+ ctx context.Context, roomID string,
+) (backwardExtremities []string, err error) {
+ return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, nil, roomID)
+}
+
+// MaxTopologicalPosition returns the highest topological position for a given
+// room.
+func (d *SyncServerDatasource) MaxTopologicalPosition(
+ ctx context.Context, roomID string,
+) (types.StreamPosition, error) {
+ return d.topology.selectMaxPositionInTopology(ctx, nil, roomID)
+}
+
+// EventsAtTopologicalPosition returns all of the events matching a given
+// position in the topology of a given room.
+func (d *SyncServerDatasource) EventsAtTopologicalPosition(
+ ctx context.Context, roomID string, pos types.StreamPosition,
+) ([]types.StreamEvent, error) {
+ eIDs, err := d.topology.selectEventIDsFromPosition(ctx, nil, roomID, pos)
+ if err != nil {
+ return nil, err
+ }
+
+ return d.events.selectEvents(ctx, nil, eIDs)
+}
+
+func (d *SyncServerDatasource) EventPositionInTopology(
+ ctx context.Context, eventID string,
+) (types.StreamPosition, error) {
+ return d.topology.selectPositionInTopology(ctx, nil, eventID)
+}
+
+// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
+func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) {
+ return d.syncStreamPositionTx(ctx, nil)
+}
+
+func (d *SyncServerDatasource) syncStreamPositionTx(
+ ctx context.Context, txn *sql.Tx,
+) (types.StreamPosition, error) {
+ maxID, err := d.events.selectMaxEventID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ if maxAccountDataID > maxID {
+ maxID = maxAccountDataID
+ }
+ maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ if maxInviteID > maxID {
+ maxID = maxInviteID
+ }
+ return types.StreamPosition(maxID), nil
+}
+
+func (d *SyncServerDatasource) syncPositionTx(
+ ctx context.Context, txn *sql.Tx,
+) (sp types.PaginationToken, err error) {
+
+ maxEventID, err := d.events.selectMaxEventID(ctx, txn)
+ if err != nil {
+ return sp, err
+ }
+ maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
+ if err != nil {
+ return sp, err
+ }
+ if maxAccountDataID > maxEventID {
+ maxEventID = maxAccountDataID
+ }
+ maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
+ if err != nil {
+ return sp, err
+ }
+ if maxInviteID > maxEventID {
+ maxEventID = maxInviteID
+ }
+ sp.PDUPosition = types.StreamPosition(maxEventID)
+ sp.EDUTypingPosition = types.StreamPosition(d.typingCache.GetLatestSyncPosition())
+ return
+}
+
+// addPDUDeltaToResponse adds all PDU deltas to a sync response.
+// IDs of all rooms the user joined are returned so EDU deltas can be added for them.
+func (d *SyncServerDatasource) addPDUDeltaToResponse(
+ ctx context.Context,
+ device authtypes.Device,
+ fromPos, toPos types.StreamPosition,
+ numRecentEventsPerRoom int,
+ wantFullState bool,
+ res *types.Response,
+) (joinedRoomIDs []string, err error) {
+ txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
+ if err != nil {
+ return nil, err
+ }
+ var succeeded bool
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
+
+ stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request
+
+ // Work out which rooms to return in the response. This is done by getting not only the currently
+ // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions.
+ // This works out what the 'state' key should be for each room as well as which membership block
+ // to put the room into.
+ var deltas []stateDelta
+ if !wantFullState {
+ deltas, joinedRoomIDs, err = d.getStateDeltas(
+ ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart,
+ )
+ } else {
+ deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync(
+ ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart,
+ )
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ for _, delta := range deltas {
+ err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // TODO: This should be done in getStateDeltas
+ if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil {
+ return nil, err
+ }
+
+ succeeded = true
+ return joinedRoomIDs, nil
+}
+
+// addTypingDeltaToResponse adds all typing notifications to a sync response
+// since the specified position.
+func (d *SyncServerDatasource) addTypingDeltaToResponse(
+ since types.PaginationToken,
+ joinedRoomIDs []string,
+ res *types.Response,
+) error {
+ var jr types.JoinResponse
+ var ok bool
+ var err error
+ for _, roomID := range joinedRoomIDs {
+ if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter(
+ roomID, int64(since.EDUTypingPosition),
+ ); updated {
+ ev := gomatrixserverlib.ClientEvent{
+ Type: gomatrixserverlib.MTyping,
+ }
+ ev.Content, err = json.Marshal(map[string]interface{}{
+ "user_ids": typingUsers,
+ })
+ if err != nil {
+ return err
+ }
+
+ if jr, ok = res.Rooms.Join[roomID]; !ok {
+ jr = *types.NewJoinResponse()
+ }
+ jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev)
+ res.Rooms.Join[roomID] = jr
+ }
+ }
+ return nil
+}
+
+// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
+// the positions of that type are not equal in fromPos and toPos.
+func (d *SyncServerDatasource) addEDUDeltaToResponse(
+ fromPos, toPos types.PaginationToken,
+ joinedRoomIDs []string,
+ res *types.Response,
+) (err error) {
+
+ if fromPos.EDUTypingPosition != toPos.EDUTypingPosition {
+ err = d.addTypingDeltaToResponse(
+ fromPos, joinedRoomIDs, res,
+ )
+ }
+
+ return
+}
+
+// IncrementalSync returns all the data needed in order to create an incremental
+// sync response for the given user. Events returned will include any client
+// transaction IDs associated with the given device. These transaction IDs come
+// from when the device sent the event via an API that included a transaction
+// ID.
+func (d *SyncServerDatasource) IncrementalSync(
+ ctx context.Context,
+ device authtypes.Device,
+ fromPos, toPos types.PaginationToken,
+ numRecentEventsPerRoom int,
+ wantFullState bool,
+) (*types.Response, error) {
+ nextBatchPos := fromPos.WithUpdates(toPos)
+ res := types.NewResponse(nextBatchPos)
+
+ var joinedRoomIDs []string
+ var err error
+ if fromPos.PDUPosition != toPos.PDUPosition || wantFullState {
+ joinedRoomIDs, err = d.addPDUDeltaToResponse(
+ ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res,
+ )
+ } else {
+ joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(
+ ctx, nil, device.UserID, gomatrixserverlib.Join,
+ )
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ err = d.addEDUDeltaToResponse(
+ fromPos, toPos, joinedRoomIDs, res,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed
+// to it. It returns toPos and joinedRoomIDs for use of adding EDUs.
+func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
+ ctx context.Context,
+ userID string,
+ numRecentEventsPerRoom int,
+) (
+ res *types.Response,
+ toPos types.PaginationToken,
+ joinedRoomIDs []string,
+ err error,
+) {
+ // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have
+ // a consistent view of the database throughout. This includes extracting the sync position.
+ // This does have the unfortunate side-effect that all the matrixy logic resides in this function,
+ // but it's better to not hide the fact that this is being done in a transaction.
+ txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
+ if err != nil {
+ return
+ }
+ var succeeded bool
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
+
+ // Get the current sync position which we will base the sync response on.
+ toPos, err = d.syncPositionTx(ctx, txn)
+ if err != nil {
+ return
+ }
+
+ res = types.NewResponse(toPos)
+
+ // Extract room state and recent events for all rooms the user is joined to.
+ joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
+ if err != nil {
+ return
+ }
+ fmt.Println("Joined rooms:", joinedRoomIDs)
+
+ stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request
+
+ // Build up a /sync response. Add joined rooms.
+ for _, roomID := range joinedRoomIDs {
+ fmt.Println("WE'RE ON", roomID)
+
+ var stateEvents []gomatrixserverlib.Event
+ stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart)
+ if err != nil {
+ fmt.Println("d.roomstate.selectCurrentState:", err)
+ return
+ }
+ //fmt.Println("State events:", stateEvents)
+ // TODO: When filters are added, we may need to call this multiple times to get enough events.
+ // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
+ var recentStreamEvents []types.StreamEvent
+ recentStreamEvents, err = d.events.selectRecentEvents(
+ ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition,
+ numRecentEventsPerRoom, true, true,
+ )
+ if err != nil {
+ fmt.Println("d.events.selectRecentEvents:", err)
+ return
+ }
+ //fmt.Println("Recent stream events:", recentStreamEvents)
+
+ // Retrieve the backward topology position, i.e. the position of the
+ // oldest event in the room's topology.
+ var backwardTopologyPos types.StreamPosition
+ backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
+ if err != nil {
+ fmt.Println("d.topology.selectPositionInTopology:", err)
+ return nil, types.PaginationToken{}, []string{}, err
+ }
+ fmt.Println("Backward topology position:", backwardTopologyPos)
+ if backwardTopologyPos-1 <= 0 {
+ backwardTopologyPos = types.StreamPosition(1)
+ } else {
+ backwardTopologyPos--
+ }
+
+ // We don't include a device here as we don't need to send down
+ // transaction IDs for complete syncs
+ recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents)
+ stateEvents = removeDuplicates(stateEvents, recentEvents)
+ jr := types.NewJoinResponse()
+ jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
+ jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
+ jr.Timeline.Limited = true
+ jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
+ res.Rooms.Join[roomID] = *jr
+ }
+
+ if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil {
+ fmt.Println("d.addInvitesToResponse:", err)
+ return
+ }
+
+ succeeded = true
+ return res, toPos, joinedRoomIDs, err
+}
+
+// CompleteSync returns a complete /sync API response for the given user.
+func (d *SyncServerDatasource) CompleteSync(
+ ctx context.Context, userID string, numRecentEventsPerRoom int,
+) (*types.Response, error) {
+ res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
+ ctx, userID, numRecentEventsPerRoom,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Use a zero value SyncPosition for fromPos so all EDU states are added.
+ err = d.addEDUDeltaToResponse(
+ types.PaginationToken{}, toPos, joinedRoomIDs, res,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+var txReadOnlySnapshot = sql.TxOptions{
+ // Set the isolation level so that we see a snapshot of the database.
+ // In PostgreSQL repeatable read transactions will see a snapshot taken
+ // at the first query, and since the transaction is read-only it can't
+ // run into any serialisation errors.
+ // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
+ Isolation: sql.LevelRepeatableRead,
+ ReadOnly: true,
+}
+
+// GetAccountDataInRange returns all account data for a given user inserted or
+// updated between two given positions
+// Returns a map following the format data[roomID] = []dataTypes
+// If no data is retrieved, returns an empty map
+// If there was an issue with the retrieval, returns an error
+func (d *SyncServerDatasource) GetAccountDataInRange(
+ ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
+ accountDataFilterPart *gomatrixserverlib.EventFilter,
+) (map[string][]string, error) {
+ return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
+}
+
+// UpsertAccountData keeps track of new or updated account data, by saving the type
+// of the new/updated data, and the user ID and room ID the data is related to (empty)
+// room ID means the data isn't specific to any room)
+// If no data with the given type, user ID and room ID exists in the database,
+// creates a new row, else update the existing one
+// Returns an error if there was an issue with the upsert
+func (d *SyncServerDatasource) UpsertAccountData(
+ ctx context.Context, userID, roomID, dataType string,
+) (sp types.StreamPosition, err error) {
+ txn, err := d.db.BeginTx(ctx, nil)
+ if err != nil {
+ return types.StreamPosition(0), err
+ }
+ var succeeded bool
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
+ sp, err = d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType)
+ return
+}
+
+// AddInviteEvent stores a new invite event for a user.
+// If the invite was successfully stored this returns the stream ID it was stored at.
+// Returns an error if there was a problem communicating with the database.
+func (d *SyncServerDatasource) AddInviteEvent(
+ ctx context.Context, inviteEvent gomatrixserverlib.Event,
+) (types.StreamPosition, error) {
+ return d.invites.insertInviteEvent(ctx, inviteEvent)
+}
+
+// RetireInviteEvent removes an old invite event from the database.
+// Returns an error if there was a problem communicating with the database.
+func (d *SyncServerDatasource) RetireInviteEvent(
+ ctx context.Context, inviteEventID string,
+) error {
+ // TODO: Record that invite has been retired in a stream so that we can
+ // notify the user in an incremental sync.
+ err := d.invites.deleteInviteEvent(ctx, inviteEventID)
+ return err
+}
+
+func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
+ d.typingCache.SetTimeoutCallback(fn)
+}
+
+// AddTypingUser adds a typing user to the typing cache.
+// Returns the newly calculated sync position for typing notifications.
+func (d *SyncServerDatasource) AddTypingUser(
+ userID, roomID string, expireTime *time.Time,
+) types.StreamPosition {
+ return types.StreamPosition(d.typingCache.AddTypingUser(userID, roomID, expireTime))
+}
+
+// RemoveTypingUser removes a typing user from the typing cache.
+// Returns the newly calculated sync position for typing notifications.
+func (d *SyncServerDatasource) RemoveTypingUser(
+ userID, roomID string,
+) types.StreamPosition {
+ return types.StreamPosition(d.typingCache.RemoveUser(userID, roomID))
+}
+
+func (d *SyncServerDatasource) addInvitesToResponse(
+ ctx context.Context, txn *sql.Tx,
+ userID string,
+ fromPos, toPos types.StreamPosition,
+ res *types.Response,
+) error {
+ invites, err := d.invites.selectInviteEventsInRange(
+ ctx, txn, userID, fromPos, toPos,
+ )
+ if err != nil {
+ return err
+ }
+ for roomID, inviteEvent := range invites {
+ ir := types.NewInviteResponse()
+ ir.InviteState.Events = gomatrixserverlib.ToClientEvents(
+ []gomatrixserverlib.Event{inviteEvent}, gomatrixserverlib.FormatSync,
+ )
+ // TODO: add the invite state from the invite event.
+ res.Rooms.Invite[roomID] = *ir
+ }
+ return nil
+}
+
+// Retrieve the backward topology position, i.e. the position of the
+// oldest event in the room's topology.
+func (d *SyncServerDatasource) getBackwardTopologyPos(
+ ctx context.Context, txn *sql.Tx,
+ events []types.StreamEvent,
+) (pos types.StreamPosition) {
+ if len(events) > 0 {
+ pos, _ = d.topology.selectPositionInTopology(ctx, txn, events[0].EventID())
+ }
+ if pos-1 <= 0 {
+ pos = types.StreamPosition(1)
+ } else {
+ pos = pos - 1
+ }
+ return
+}
+
+// addRoomDeltaToResponse adds a room state delta to a sync response
+func (d *SyncServerDatasource) addRoomDeltaToResponse(
+ ctx context.Context,
+ device *authtypes.Device,
+ txn *sql.Tx,
+ fromPos, toPos types.StreamPosition,
+ delta stateDelta,
+ numRecentEventsPerRoom int,
+ res *types.Response,
+) error {
+ endPos := toPos
+ if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave {
+ // make sure we don't leak recent events after the leave event.
+ // TODO: History visibility makes this somewhat complex to handle correctly. For example:
+ // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join).
+ // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave
+ // in a single /sync request
+ // This is all "okay" assuming history_visibility == "shared" which it is by default.
+ endPos = delta.membershipPos
+ }
+ recentStreamEvents, err := d.events.selectRecentEvents(
+ ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos),
+ numRecentEventsPerRoom, true, true,
+ )
+ if err != nil {
+ return err
+ }
+ recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
+ delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
+ backwardTopologyPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents)
+
+ switch delta.membership {
+ case gomatrixserverlib.Join:
+ jr := types.NewJoinResponse()
+
+ jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
+ jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
+ jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
+ jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
+ res.Rooms.Join[delta.roomID] = *jr
+ case gomatrixserverlib.Leave:
+ fallthrough // transitions to leave are the same as ban
+ case gomatrixserverlib.Ban:
+ // TODO: recentEvents may contain events that this user is not allowed to see because they are
+ // no longer in the room.
+ lr := types.NewLeaveResponse()
+ lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
+ lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
+ lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
+ lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
+ res.Rooms.Leave[delta.roomID] = *lr
+ }
+
+ return nil
+}
+
+// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
+// Returns a map of room ID to list of events.
+func (d *SyncServerDatasource) fetchStateEvents(
+ ctx context.Context, txn *sql.Tx,
+ roomIDToEventIDSet map[string]map[string]bool,
+ eventIDToEvent map[string]types.StreamEvent,
+) (map[string][]types.StreamEvent, error) {
+ stateBetween := make(map[string][]types.StreamEvent)
+ missingEvents := make(map[string][]string)
+ for roomID, ids := range roomIDToEventIDSet {
+ events := stateBetween[roomID]
+ for id, need := range ids {
+ if !need {
+ continue // deleted state
+ }
+ e, ok := eventIDToEvent[id]
+ if ok {
+ events = append(events, e)
+ } else {
+ m := missingEvents[roomID]
+ m = append(m, id)
+ missingEvents[roomID] = m
+ }
+ }
+ stateBetween[roomID] = events
+ }
+
+ if len(missingEvents) > 0 {
+ // This happens when add_state_ids has an event ID which is not in the provided range.
+ // We need to explicitly fetch them.
+ allMissingEventIDs := []string{}
+ for _, missingEvIDs := range missingEvents {
+ allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
+ }
+ evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs)
+ if err != nil {
+ return nil, err
+ }
+ // we know we got them all otherwise an error would've been returned, so just loop the events
+ for _, ev := range evs {
+ roomID := ev.RoomID()
+ stateBetween[roomID] = append(stateBetween[roomID], ev)
+ }
+ }
+ return stateBetween, nil
+}
+
+func (d *SyncServerDatasource) fetchMissingStateEvents(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) ([]types.StreamEvent, error) {
+ // Fetch from the events table first so we pick up the stream ID for the
+ // event.
+ events, err := d.events.selectEvents(ctx, txn, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ have := map[string]bool{}
+ for _, event := range events {
+ have[event.EventID()] = true
+ }
+ var missing []string
+ for _, eventID := range eventIDs {
+ if !have[eventID] {
+ missing = append(missing, eventID)
+ }
+ }
+ if len(missing) == 0 {
+ return events, nil
+ }
+
+ // If they are missing from the events table then they should be state
+ // events that we received from outside the main event stream.
+ // These should be in the room state table.
+ stateEvents, err := d.roomstate.selectEventsWithEventIDs(ctx, txn, missing)
+
+ if err != nil {
+ return nil, err
+ }
+ if len(stateEvents) != len(missing) {
+ return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
+ }
+ events = append(events, stateEvents...)
+ return events, nil
+}
+
+// getStateDeltas returns the state deltas between fromPos and toPos,
+// exclusive of oldPos, inclusive of newPos, for the rooms in which
+// the user has new membership events.
+// A list of joined room IDs is also returned in case the caller needs it.
+func (d *SyncServerDatasource) getStateDeltas(
+ ctx context.Context, device *authtypes.Device, txn *sql.Tx,
+ fromPos, toPos types.StreamPosition, userID string,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) ([]stateDelta, []string, error) {
+ // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
+ // - Get membership list changes for this user in this sync response
+ // - For each room which has membership list changes:
+ // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO).
+ // If it is, then we need to send the full room state down (and 'limited' is always true).
+ // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
+ // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
+ // - Get all CURRENTLY joined rooms, and add them to 'joined' block.
+ var deltas []stateDelta
+
+ // get all the state events ever between these two positions
+ stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart)
+ if err != nil {
+ return nil, nil, err
+ }
+ state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ for roomID, stateStreamEvents := range state {
+ for _, ev := range stateStreamEvents {
+ // TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event.
+ // We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this,
+ // dupe join events will result in the entire room state coming down to the client again. This is added in
+ // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
+ // the timeline.
+ if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
+ if membership == gomatrixserverlib.Join {
+ // send full room state down instead of a delta
+ var s []types.StreamEvent
+ s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart)
+ if err != nil {
+ return nil, nil, err
+ }
+ state[roomID] = s
+ continue // we'll add this room in when we do joined rooms
+ }
+
+ deltas = append(deltas, stateDelta{
+ membership: membership,
+ membershipPos: ev.StreamPosition,
+ stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
+ roomID: roomID,
+ })
+ break
+ }
+ }
+ }
+
+ // Add in currently joined rooms
+ joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
+ if err != nil {
+ return nil, nil, err
+ }
+ for _, joinedRoomID := range joinedRoomIDs {
+ deltas = append(deltas, stateDelta{
+ membership: gomatrixserverlib.Join,
+ stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
+ roomID: joinedRoomID,
+ })
+ }
+
+ return deltas, joinedRoomIDs, nil
+}
+
+// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
+// requests with full_state=true.
+// Fetches full state for all joined rooms and uses selectStateInRange to get
+// updates for other rooms.
+func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
+ ctx context.Context, device *authtypes.Device, txn *sql.Tx,
+ fromPos, toPos types.StreamPosition, userID string,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) ([]stateDelta, []string, error) {
+ joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Use a reasonable initial capacity
+ deltas := make([]stateDelta, 0, len(joinedRoomIDs))
+
+ // Add full states for all joined rooms
+ for _, joinedRoomID := range joinedRoomIDs {
+ s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilterPart)
+ if stateErr != nil {
+ return nil, nil, stateErr
+ }
+ deltas = append(deltas, stateDelta{
+ membership: gomatrixserverlib.Join,
+ stateEvents: d.StreamEventsToEvents(device, s),
+ roomID: joinedRoomID,
+ })
+ }
+
+ // Get all the state events ever between these two positions
+ stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart)
+ if err != nil {
+ return nil, nil, err
+ }
+ state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ for roomID, stateStreamEvents := range state {
+ for _, ev := range stateStreamEvents {
+ if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
+ if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
+ deltas = append(deltas, stateDelta{
+ membership: membership,
+ membershipPos: ev.StreamPosition,
+ stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
+ roomID: roomID,
+ })
+ }
+
+ break
+ }
+ }
+ }
+
+ return deltas, joinedRoomIDs, nil
+}
+
+func (d *SyncServerDatasource) currentStateStreamEventsForRoom(
+ ctx context.Context, txn *sql.Tx, roomID string,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) ([]types.StreamEvent, error) {
+ allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
+ if err != nil {
+ return nil, err
+ }
+ s := make([]types.StreamEvent, len(allState))
+ for i := 0; i < len(s); i++ {
+ s[i] = types.StreamEvent{Event: allState[i], StreamPosition: 0}
+ }
+ return s, nil
+}
+
+// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
+// matches the streamevent.transactionID device then the transaction ID gets
+// added to the unsigned section of the output event.
+func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event {
+ out := make([]gomatrixserverlib.Event, len(in))
+ for i := 0; i < len(in); i++ {
+ out[i] = in[i].Event
+ if device != nil && in[i].TransactionID != nil {
+ if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
+ err := out[i].SetUnsignedField(
+ "transaction_id", in[i].TransactionID.TransactionID,
+ )
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ "event_id": out[i].EventID(),
+ }).WithError(err).Warnf("Failed to add transaction ID to event")
+ }
+ }
+ }
+ }
+ return out
+}
+
+// There may be some overlap where events in stateEvents are already in recentEvents, so filter
+// them out so we don't include them twice in the /sync response. They should be in recentEvents
+// only, so clients get to the correct state once they have rolled forward.
+func removeDuplicates(stateEvents, recentEvents []gomatrixserverlib.Event) []gomatrixserverlib.Event {
+ for _, recentEv := range recentEvents {
+ if recentEv.StateKey() == nil {
+ continue // not a state event
+ }
+ // TODO: This is a linear scan over all the current state events in this room. This will
+ // be slow for big rooms. We should instead sort the state events by event ID (ORDER BY)
+ // then do a binary search to find matching events, similar to what roomserver does.
+ for j := 0; j < len(stateEvents); j++ {
+ if stateEvents[j].EventID() == recentEv.EventID() {
+ // overwrite the element to remove with the last element then pop the last element.
+ // This is orders of magnitude faster than re-slicing, but doesn't preserve ordering
+ // (we don't care about the order of stateEvents)
+ stateEvents[j] = stateEvents[len(stateEvents)-1]
+ stateEvents = stateEvents[:len(stateEvents)-1]
+ break // there shouldn't be multiple events with the same event ID
+ }
+ }
+ }
+ return stateEvents
+}
+
+// getMembershipFromEvent returns the value of content.membership iff the event is a state event
+// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
+func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string {
+ if ev.Type() == "m.room.member" && ev.StateKeyEquals(userID) {
+ membership, err := ev.Membership()
+ if err != nil {
+ return ""
+ }
+ return membership
+ }
+ return ""
+}
diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go
index e6392844..c87024b2 100644
--- a/syncapi/storage/storage.go
+++ b/syncapi/storage/storage.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
+ "github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/typingserver/cache"
"github.com/matrix-org/gomatrixserverlib"
@@ -63,6 +64,8 @@ func NewSyncServerDatasource(dataSourceName string) (Database, error) {
switch uri.Scheme {
case "postgres":
return postgres.NewSyncServerDatasource(dataSourceName)
+ case "file":
+ return sqlite3.NewSyncServerDatasource(dataSourceName)
default:
return postgres.NewSyncServerDatasource(dataSourceName)
}
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index 06a8d6d8..22bd239f 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -32,12 +32,12 @@ import (
// RequestPool manages HTTP long-poll connections for /sync
type RequestPool struct {
db storage.Database
- accountDB *accounts.Database
+ accountDB accounts.Database
notifier *Notifier
}
// NewRequestPool makes a new RequestPool
-func NewRequestPool(db storage.Database, n *Notifier, adb *accounts.Database) *RequestPool {
+func NewRequestPool(db storage.Database, n *Notifier, adb accounts.Database) *RequestPool {
return &RequestPool{db, adb, n}
}
diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go
index ecf532ca..1535d2b1 100644
--- a/syncapi/syncapi.go
+++ b/syncapi/syncapi.go
@@ -36,8 +36,8 @@ import (
// component.
func SetupSyncAPIComponent(
base *basecomponent.BaseDendrite,
- deviceDB *devices.Database,
- accountsDB *accounts.Database,
+ deviceDB devices.Database,
+ accountsDB accounts.Database,
queryAPI api.RoomserverQueryAPI,
federation *gomatrixserverlib.FederationClient,
cfg *config.Dendrite,