aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/routing/routing.go2
-rw-r--r--syncapi/storage/postgres/syncserver.go17
-rw-r--r--syncapi/storage/sqlite3/account_data_table.go143
-rw-r--r--syncapi/storage/sqlite3/backward_extremities_table.go124
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go276
-rw-r--r--syncapi/storage/sqlite3/filtering.go36
-rw-r--r--syncapi/storage/sqlite3/invites_table.go157
-rw-r--r--syncapi/storage/sqlite3/output_room_events_table.go411
-rw-r--r--syncapi/storage/sqlite3/output_room_events_topology_table.go192
-rw-r--r--syncapi/storage/sqlite3/stream_id_table.go58
-rw-r--r--syncapi/storage/sqlite3/syncserver.go1197
-rw-r--r--syncapi/storage/storage.go3
-rw-r--r--syncapi/sync/requestpool.go4
-rw-r--r--syncapi/syncapi.go4
14 files changed, 2615 insertions, 9 deletions
diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go
index 8916565d..be90e0a0 100644
--- a/syncapi/routing/routing.go
+++ b/syncapi/routing/routing.go
@@ -39,7 +39,7 @@ const pathPrefixR0 = "/_matrix/client/r0"
// nolint: gocyclo
func Setup(
apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database,
- deviceDB *devices.Database, federation *gomatrixserverlib.FederationClient,
+ deviceDB devices.Database, federation *gomatrixserverlib.FederationClient,
queryAPI api.RoomserverQueryAPI,
cfg *config.Dendrite,
) {
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go
index aec37185..6a33a8b4 100644
--- a/syncapi/storage/postgres/syncserver.go
+++ b/syncapi/storage/postgres/syncserver.go
@@ -413,13 +413,18 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
numRecentEventsPerRoom int,
wantFullState bool,
res *types.Response,
-) ([]string, error) {
+) (joinedRoomIDs []string, err error) {
txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
if err != nil {
return nil, err
}
var succeeded bool
- defer common.EndTransaction(txn, &succeeded)
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
stateFilter := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request
@@ -428,7 +433,6 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
// This works out what the 'state' key should be for each room as well as which membership block
// to put the room into.
var deltas []stateDelta
- var joinedRoomIDs []string
if !wantFullState {
deltas, joinedRoomIDs, err = d.getStateDeltas(
ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilter,
@@ -570,7 +574,12 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
return
}
var succeeded bool
- defer common.EndTransaction(txn, &succeeded)
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
// Get the current sync position which we will base the sync response on.
toPos, err = d.syncPositionTx(ctx, txn)
diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go
new file mode 100644
index 00000000..3274e66e
--- /dev/null
+++ b/syncapi/storage/sqlite3/account_data_table.go
@@ -0,0 +1,143 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const accountDataSchema = `
+CREATE TABLE IF NOT EXISTS syncapi_account_data_type (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ UNIQUE (user_id, room_id, type)
+);
+`
+
+const insertAccountDataSQL = "" +
+ "INSERT INTO syncapi_account_data_type (id, user_id, room_id, type) VALUES ($1, $2, $3, $4)" +
+ " ON CONFLICT (user_id, room_id, type) DO UPDATE" +
+ " SET id = EXCLUDED.id"
+
+const selectAccountDataInRangeSQL = "" +
+ "SELECT room_id, type FROM syncapi_account_data_type" +
+ " WHERE user_id = $1 AND id > $2 AND id <= $3" +
+ " AND ( $4 IS NULL OR type IN ($4) )" +
+ " AND ( $5 IS NULL OR NOT(type IN ($5)) )" +
+ " ORDER BY id ASC LIMIT $6"
+
+const selectMaxAccountDataIDSQL = "" +
+ "SELECT MAX(id) FROM syncapi_account_data_type"
+
+type accountDataStatements struct {
+ streamIDStatements *streamIDStatements
+ insertAccountDataStmt *sql.Stmt
+ selectAccountDataInRangeStmt *sql.Stmt
+ selectMaxAccountDataIDStmt *sql.Stmt
+}
+
+func (s *accountDataStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
+ s.streamIDStatements = streamID
+ _, err = db.Exec(accountDataSchema)
+ if err != nil {
+ return
+ }
+ if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil {
+ return
+ }
+ if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil {
+ return
+ }
+ if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *accountDataStatements) insertAccountData(
+ ctx context.Context, txn *sql.Tx,
+ userID, roomID, dataType string,
+) (pos types.StreamPosition, err error) {
+ pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return
+ }
+ insertStmt := common.TxStmt(txn, s.insertAccountDataStmt)
+ _, err = insertStmt.ExecContext(ctx, pos, userID, roomID, dataType)
+ return
+}
+
+func (s *accountDataStatements) selectAccountDataInRange(
+ ctx context.Context,
+ userID string,
+ oldPos, newPos types.StreamPosition,
+ accountDataFilterPart *gomatrixserverlib.EventFilter,
+) (data map[string][]string, err error) {
+ data = make(map[string][]string)
+
+ // If both positions are the same, it means that the data was saved after the
+ // latest room event. In that case, we need to decrement the old position as
+ // it would prevent the SQL request from returning anything.
+ if oldPos == newPos {
+ oldPos--
+ }
+
+ rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos,
+ pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.Types)),
+ pq.StringArray(filterConvertTypeWildcardToSQL(accountDataFilterPart.NotTypes)),
+ accountDataFilterPart.Limit,
+ )
+ if err != nil {
+ return
+ }
+
+ for rows.Next() {
+ var dataType string
+ var roomID string
+
+ if err = rows.Scan(&roomID, &dataType); err != nil {
+ return
+ }
+
+ if len(data[roomID]) > 0 {
+ data[roomID] = append(data[roomID], dataType)
+ } else {
+ data[roomID] = []string{dataType}
+ }
+ }
+
+ return
+}
+
+func (s *accountDataStatements) selectMaxAccountDataID(
+ ctx context.Context, txn *sql.Tx,
+) (id int64, err error) {
+ var nullableID sql.NullInt64
+ stmt := common.TxStmt(txn, s.selectMaxAccountDataIDStmt)
+ err = stmt.QueryRowContext(ctx).Scan(&nullableID)
+ if nullableID.Valid {
+ id = nullableID.Int64
+ }
+ return
+}
diff --git a/syncapi/storage/sqlite3/backward_extremities_table.go b/syncapi/storage/sqlite3/backward_extremities_table.go
new file mode 100644
index 00000000..fcf15da2
--- /dev/null
+++ b/syncapi/storage/sqlite3/backward_extremities_table.go
@@ -0,0 +1,124 @@
+// Copyright 2018 New Vector Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+)
+
+const backwardExtremitiesSchema = `
+-- Stores output room events received from the roomserver.
+CREATE TABLE IF NOT EXISTS syncapi_backward_extremities (
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+
+ PRIMARY KEY(room_id, event_id)
+);
+`
+
+const insertBackwardExtremitySQL = "" +
+ "INSERT INTO syncapi_backward_extremities (room_id, event_id)" +
+ " VALUES ($1, $2)" +
+ " ON CONFLICT (room_id, event_id) DO NOTHING"
+
+const selectBackwardExtremitiesForRoomSQL = "" +
+ "SELECT event_id FROM syncapi_backward_extremities WHERE room_id = $1"
+
+const isBackwardExtremitySQL = "" +
+ "SELECT EXISTS (" +
+ " SELECT TRUE FROM syncapi_backward_extremities" +
+ " WHERE room_id = $1 AND event_id = $2" +
+ ")"
+
+const deleteBackwardExtremitySQL = "" +
+ "DELETE FROM syncapi_backward_extremities" +
+ " WHERE room_id = $1 AND event_id = $2"
+
+type backwardExtremitiesStatements struct {
+ insertBackwardExtremityStmt *sql.Stmt
+ selectBackwardExtremitiesForRoomStmt *sql.Stmt
+ isBackwardExtremityStmt *sql.Stmt
+ deleteBackwardExtremityStmt *sql.Stmt
+}
+
+func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(backwardExtremitiesSchema)
+ if err != nil {
+ return
+ }
+ if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil {
+ return
+ }
+ if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil {
+ return
+ }
+ if s.isBackwardExtremityStmt, err = db.Prepare(isBackwardExtremitySQL); err != nil {
+ return
+ }
+ if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *backwardExtremitiesStatements) insertsBackwardExtremity(
+ ctx context.Context, txn *sql.Tx, roomID, eventID string,
+) (err error) {
+ stmt := common.TxStmt(txn, s.insertBackwardExtremityStmt)
+ _, err = stmt.ExecContext(ctx, roomID, eventID)
+ return
+}
+
+func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (eventIDs []string, err error) {
+ eventIDs = make([]string, 0)
+
+ stmt := common.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt)
+ rows, err := stmt.QueryContext(ctx, roomID)
+ if err != nil {
+ return
+ }
+
+ for rows.Next() {
+ var eID string
+ if err = rows.Scan(&eID); err != nil {
+ return
+ }
+
+ eventIDs = append(eventIDs, eID)
+ }
+
+ return
+}
+
+func (s *backwardExtremitiesStatements) isBackwardExtremity(
+ ctx context.Context, txn *sql.Tx, roomID, eventID string,
+) (isBE bool, err error) {
+ stmt := common.TxStmt(txn, s.isBackwardExtremityStmt)
+ err = stmt.QueryRowContext(ctx, roomID, eventID).Scan(&isBE)
+ return
+}
+
+func (s *backwardExtremitiesStatements) deleteBackwardExtremity(
+ ctx context.Context, txn *sql.Tx, roomID, eventID string,
+) (err error) {
+ stmt := common.TxStmt(txn, s.deleteBackwardExtremityStmt)
+ _, err = stmt.ExecContext(ctx, roomID, eventID)
+ return
+}
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
new file mode 100644
index 00000000..4ce94666
--- /dev/null
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -0,0 +1,276 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const currentRoomStateSchema = `
+-- Stores the current room state for every room.
+CREATE TABLE IF NOT EXISTS syncapi_current_room_state (
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ sender TEXT NOT NULL,
+ contains_url BOOL NOT NULL DEFAULT false,
+ state_key TEXT NOT NULL,
+ event_json TEXT NOT NULL,
+ membership TEXT,
+ added_at BIGINT,
+ UNIQUE (room_id, type, state_key)
+);
+-- for event deletion
+CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_state(event_id, room_id, type, sender, contains_url);
+-- for querying membership states of users
+-- CREATE INDEX IF NOT EXISTS syncapi_membership_idx ON syncapi_current_room_state(type, state_key, membership) WHERE membership IS NOT NULL AND membership != 'leave';
+`
+
+const upsertRoomStateSQL = "" +
+ "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, event_json, membership, added_at)" +
+ " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
+ " ON CONFLICT (event_id, room_id, type, sender, contains_url)" +
+ " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, event_json = $7, membership = $8, added_at = $9"
+
+const deleteRoomStateByEventIDSQL = "" +
+ "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
+
+const selectRoomIDsWithMembershipSQL = "" +
+ "SELECT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
+
+const selectCurrentStateSQL = "" +
+ "SELECT event_json FROM syncapi_current_room_state WHERE room_id = $1" +
+ " AND ( $2 IS NULL OR sender IN ($2) )" +
+ " AND ( $3 IS NULL OR NOT(sender IN ($3)) )" +
+ " AND ( $4 IS NULL OR type IN ($4) )" +
+ " AND ( $5 IS NULL OR NOT(type IN ($5)) )" +
+ " AND ( $6 IS NULL OR contains_url = $6 )" +
+ " LIMIT $7"
+
+const selectJoinedUsersSQL = "" +
+ "SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
+
+const selectStateEventSQL = "" +
+ "SELECT event_json FROM syncapi_current_room_state WHERE room_id = $1 AND type = $2 AND state_key = $3"
+
+const selectEventsWithEventIDsSQL = "" +
+ // TODO: The session_id and transaction_id blanks are here because otherwise
+ // the rowsToStreamEvents expects there to be exactly five columns. We need to
+ // figure out if these really need to be in the DB, and if so, we need a
+ // better permanent fix for this. - neilalexander, 2 Jan 2020
+ "SELECT added_at, event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
+ " FROM syncapi_current_room_state WHERE event_id IN ($1)"
+
+type currentRoomStateStatements struct {
+ streamIDStatements *streamIDStatements
+ upsertRoomStateStmt *sql.Stmt
+ deleteRoomStateByEventIDStmt *sql.Stmt
+ selectRoomIDsWithMembershipStmt *sql.Stmt
+ selectCurrentStateStmt *sql.Stmt
+ selectJoinedUsersStmt *sql.Stmt
+ selectEventsWithEventIDsStmt *sql.Stmt
+ selectStateEventStmt *sql.Stmt
+}
+
+func (s *currentRoomStateStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
+ s.streamIDStatements = streamID
+ _, err = db.Exec(currentRoomStateSchema)
+ if err != nil {
+ return
+ }
+ if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
+ return
+ }
+ if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
+ return
+ }
+ if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
+ return
+ }
+ if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
+ return
+ }
+ if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
+ return
+ }
+ if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil {
+ return
+ }
+ if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil {
+ return
+ }
+ return
+}
+
+// JoinedMemberLists returns a map of room ID to a list of joined user IDs.
+func (s *currentRoomStateStatements) selectJoinedUsers(
+ ctx context.Context,
+) (map[string][]string, error) {
+ rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ result := make(map[string][]string)
+ for rows.Next() {
+ var roomID string
+ var userID string
+ if err := rows.Scan(&roomID, &userID); err != nil {
+ return nil, err
+ }
+ users := result[roomID]
+ users = append(users, userID)
+ result[roomID] = users
+ }
+ return result, nil
+}
+
+// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
+func (s *currentRoomStateStatements) selectRoomIDsWithMembership(
+ ctx context.Context,
+ txn *sql.Tx,
+ userID string,
+ membership string, // nolint: unparam
+) ([]string, error) {
+ stmt := common.TxStmt(txn, s.selectRoomIDsWithMembershipStmt)
+ rows, err := stmt.QueryContext(ctx, userID, membership)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ var result []string
+ for rows.Next() {
+ var roomID string
+ if err := rows.Scan(&roomID); err != nil {
+ return nil, err
+ }
+ result = append(result, roomID)
+ }
+ return result, nil
+}
+
+// CurrentState returns all the current state events for the given room.
+func (s *currentRoomStateStatements) selectCurrentState(
+ ctx context.Context, txn *sql.Tx, roomID string,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) ([]gomatrixserverlib.Event, error) {
+ stmt := common.TxStmt(txn, s.selectCurrentStateStmt)
+ rows, err := stmt.QueryContext(ctx, roomID,
+ pq.StringArray(stateFilterPart.Senders),
+ pq.StringArray(stateFilterPart.NotSenders),
+ pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
+ pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
+ stateFilterPart.ContainsURL,
+ stateFilterPart.Limit,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+
+ return rowsToEvents(rows)
+}
+
+func (s *currentRoomStateStatements) deleteRoomStateByEventID(
+ ctx context.Context, txn *sql.Tx, eventID string,
+) error {
+ stmt := common.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
+ _, err := stmt.ExecContext(ctx, eventID)
+ return err
+}
+
+func (s *currentRoomStateStatements) upsertRoomState(
+ ctx context.Context, txn *sql.Tx,
+ event gomatrixserverlib.Event, membership *string, addedAt types.StreamPosition,
+) error {
+ // Parse content as JSON and search for an "url" key
+ containsURL := false
+ var content map[string]interface{}
+ if json.Unmarshal(event.Content(), &content) != nil {
+ // Set containsURL to true if url is present
+ _, containsURL = content["url"]
+ }
+
+ // upsert state event
+ stmt := common.TxStmt(txn, s.upsertRoomStateStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ event.RoomID(),
+ event.EventID(),
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ *event.StateKey(),
+ event.JSON(),
+ membership,
+ addedAt,
+ )
+ return err
+}
+
+func (s *currentRoomStateStatements) selectEventsWithEventIDs(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) ([]types.StreamEvent, error) {
+ stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt)
+ rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ return rowsToStreamEvents(rows)
+}
+
+func rowsToEvents(rows *sql.Rows) ([]gomatrixserverlib.Event, error) {
+ result := []gomatrixserverlib.Event{}
+ for rows.Next() {
+ var eventBytes []byte
+ if err := rows.Scan(&eventBytes); err != nil {
+ return nil, err
+ }
+ // TODO: Handle redacted events
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, ev)
+ }
+ return result, nil
+}
+
+func (s *currentRoomStateStatements) selectStateEvent(
+ ctx context.Context, roomID, evType, stateKey string,
+) (*gomatrixserverlib.Event, error) {
+ stmt := s.selectStateEventStmt
+ var res []byte
+ err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON(res, false)
+ return &ev, err
+}
diff --git a/syncapi/storage/sqlite3/filtering.go b/syncapi/storage/sqlite3/filtering.go
new file mode 100644
index 00000000..c4a2f4bf
--- /dev/null
+++ b/syncapi/storage/sqlite3/filtering.go
@@ -0,0 +1,36 @@
+// Copyright 2017 Thibaut CHARLES
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "strings"
+)
+
+// filterConvertWildcardToSQL converts wildcards as defined in
+// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
+// to SQL wildcards that can be used with LIKE()
+func filterConvertTypeWildcardToSQL(values []string) []string {
+ if values == nil {
+ // Return nil instead of []string{} so IS NULL can work correctly when
+ // the return value is passed into SQL queries
+ return nil
+ }
+
+ ret := make([]string, len(values))
+ for i := range values {
+ ret[i] = strings.Replace(values[i], "*", "%", -1)
+ }
+ return ret
+}
diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go
new file mode 100644
index 00000000..74dba245
--- /dev/null
+++ b/syncapi/storage/sqlite3/invites_table.go
@@ -0,0 +1,157 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const inviteEventsSchema = `
+CREATE TABLE IF NOT EXISTS syncapi_invite_events (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ event_json TEXT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS syncapi_invites_target_user_id_idx ON syncapi_invite_events (target_user_id, id);
+CREATE INDEX IF NOT EXISTS syncapi_invites_event_id_idx ON syncapi_invite_events (event_id);
+`
+
+const insertInviteEventSQL = "" +
+ "INSERT INTO syncapi_invite_events" +
+ " (room_id, event_id, target_user_id, event_json)" +
+ " VALUES ($1, $2, $3, $4)"
+
+const selectLastInsertedInviteEventSQL = "" +
+ "SELECT id FROM syncapi_invite_events WHERE rowid = last_insert_rowid()"
+
+const deleteInviteEventSQL = "" +
+ "DELETE FROM syncapi_invite_events WHERE event_id = $1"
+
+const selectInviteEventsInRangeSQL = "" +
+ "SELECT room_id, event_json FROM syncapi_invite_events" +
+ " WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
+ " ORDER BY id DESC"
+
+const selectMaxInviteIDSQL = "" +
+ "SELECT MAX(id) FROM syncapi_invite_events"
+
+type inviteEventsStatements struct {
+ streamIDStatements *streamIDStatements
+ insertInviteEventStmt *sql.Stmt
+ selectLastInsertedInviteEventStmt *sql.Stmt
+ selectInviteEventsInRangeStmt *sql.Stmt
+ deleteInviteEventStmt *sql.Stmt
+ selectMaxInviteIDStmt *sql.Stmt
+}
+
+func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
+ s.streamIDStatements = streamID
+ _, err = db.Exec(inviteEventsSchema)
+ if err != nil {
+ return
+ }
+ if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil {
+ return
+ }
+ if s.selectLastInsertedInviteEventStmt, err = db.Prepare(selectLastInsertedInviteEventSQL); err != nil {
+ return
+ }
+ if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil {
+ return
+ }
+ if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil {
+ return
+ }
+ if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *inviteEventsStatements) insertInviteEvent(
+ ctx context.Context, inviteEvent gomatrixserverlib.Event,
+) (streamPos types.StreamPosition, err error) {
+ _, err = s.insertInviteEventStmt.ExecContext(
+ ctx,
+ inviteEvent.RoomID(),
+ inviteEvent.EventID(),
+ *inviteEvent.StateKey(),
+ inviteEvent.JSON(),
+ )
+ if err != nil {
+ return
+ }
+ err = s.selectLastInsertedInviteEventStmt.QueryRowContext(ctx).Scan(&streamPos)
+ return
+}
+
+func (s *inviteEventsStatements) deleteInviteEvent(
+ ctx context.Context, inviteEventID string,
+) error {
+ _, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID)
+ return err
+}
+
+// selectInviteEventsInRange returns a map of room ID to invite event for the
+// active invites for the target user ID in the supplied range.
+func (s *inviteEventsStatements) selectInviteEventsInRange(
+ ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition,
+) (map[string]gomatrixserverlib.Event, error) {
+ stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt)
+ rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ result := map[string]gomatrixserverlib.Event{}
+ for rows.Next() {
+ var (
+ roomID string
+ eventJSON []byte
+ )
+ if err = rows.Scan(&roomID, &eventJSON); err != nil {
+ return nil, err
+ }
+
+ event, err := gomatrixserverlib.NewEventFromTrustedJSON(eventJSON, false)
+ if err != nil {
+ return nil, err
+ }
+
+ result[roomID] = event
+ }
+ return result, nil
+}
+
+func (s *inviteEventsStatements) selectMaxInviteID(
+ ctx context.Context, txn *sql.Tx,
+) (id int64, err error) {
+ var nullableID sql.NullInt64
+ stmt := common.TxStmt(txn, s.selectMaxInviteIDStmt)
+ err = stmt.QueryRowContext(ctx).Scan(&nullableID)
+ if nullableID.Valid {
+ id = nullableID.Int64
+ }
+ return
+}
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
new file mode 100644
index 00000000..8c01f2ce
--- /dev/null
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -0,0 +1,411 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "sort"
+
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/types"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+const outputRoomEventsSchema = `
+-- Stores output room events received from the roomserver.
+CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ event_id TEXT NOT NULL UNIQUE,
+ room_id TEXT NOT NULL,
+ event_json TEXT NOT NULL,
+ type TEXT NOT NULL,
+ sender TEXT NOT NULL,
+ contains_url BOOL NOT NULL,
+ add_state_ids TEXT[],
+ remove_state_ids TEXT[],
+ session_id BIGINT,
+ transaction_id TEXT,
+ exclude_from_sync BOOL DEFAULT FALSE
+);
+`
+
+const insertEventSQL = "" +
+ "INSERT INTO syncapi_output_room_events (" +
+ "id, room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
+ ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " +
+ "ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = $11"
+
+const selectLastInsertedEventSQL = "" +
+ "SELECT id FROM syncapi_output_room_events WHERE rowid = last_insert_rowid()"
+
+const selectEventsSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
+
+const selectRecentEventsSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ " WHERE room_id = $1 AND id > $2 AND id <= $3" +
+ " ORDER BY id DESC LIMIT $4"
+
+const selectRecentEventsForSyncSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
+ " ORDER BY id DESC LIMIT $4"
+
+const selectEarlyEventsSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ " WHERE room_id = $1 AND id > $2 AND id <= $3" +
+ " ORDER BY id ASC LIMIT $4"
+
+const selectMaxEventIDSQL = "" +
+ "SELECT MAX(id) FROM syncapi_output_room_events"
+
+// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
+/*
+ $1 = oldPos,
+ $2 = newPos,
+ $3 = pq.StringArray(stateFilterPart.Senders),
+ $4 = pq.StringArray(stateFilterPart.NotSenders),
+ $5 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
+ $6 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
+ $7 = stateFilterPart.ContainsURL,
+ $8 = stateFilterPart.Limit,
+*/
+const selectStateInRangeSQL = "" +
+ "SELECT id, event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
+ " FROM syncapi_output_room_events" +
+ " WHERE (id > $1 AND id <= $2)" + // old/new pos
+ " AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
+ /* " AND ( $3 IS NULL OR sender IN ($3) )" + // sender
+ " AND ( $4 IS NULL OR NOT(sender IN ($4)) )" + // not sender
+ " AND ( $5 IS NULL OR type IN ($5) )" + // type
+ " AND ( $6 IS NULL OR NOT(type IN ($6)) )" + // not type
+ " AND ( $7 IS NULL OR contains_url = $7)" + // contains URL? */
+ " ORDER BY id ASC" +
+ " LIMIT $8" // limit
+
+type outputRoomEventsStatements struct {
+ streamIDStatements *streamIDStatements
+ insertEventStmt *sql.Stmt
+ selectLastInsertedEventStmt *sql.Stmt
+ selectEventsStmt *sql.Stmt
+ selectMaxEventIDStmt *sql.Stmt
+ selectRecentEventsStmt *sql.Stmt
+ selectRecentEventsForSyncStmt *sql.Stmt
+ selectEarlyEventsStmt *sql.Stmt
+ selectStateInRangeStmt *sql.Stmt
+}
+
+func (s *outputRoomEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) {
+ s.streamIDStatements = streamID
+ _, err = db.Exec(outputRoomEventsSchema)
+ if err != nil {
+ return
+ }
+ if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
+ return
+ }
+ if s.selectLastInsertedEventStmt, err = db.Prepare(selectLastInsertedEventSQL); err != nil {
+ return
+ }
+ if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil {
+ return
+ }
+ if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil {
+ return
+ }
+ if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
+ return
+ }
+ if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil {
+ return
+ }
+ if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil {
+ return
+ }
+ if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
+ return
+ }
+ return
+}
+
+// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
+// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
+// two positions, only the most recent state is returned.
+func (s *outputRoomEventsStatements) selectStateInRange(
+ ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
+ stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
+
+ rows, err := stmt.QueryContext(
+ ctx, oldPos, newPos,
+ /*pq.StringArray(stateFilterPart.Senders),
+ pq.StringArray(stateFilterPart.NotSenders),
+ pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
+ pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
+ stateFilterPart.ContainsURL,*/
+ stateFilterPart.Limit,
+ )
+ if err != nil {
+ return nil, nil, err
+ }
+ // Fetch all the state change events for all rooms between the two positions then loop each event and:
+ // - Keep a cache of the event by ID (99% of state change events are for the event itself)
+ // - For each room ID, build up an array of event IDs which represents cumulative adds/removes
+ // For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID
+ // if they aren't in the event ID cache. We don't handle state deletion yet.
+ eventIDToEvent := make(map[string]types.StreamEvent)
+
+ // RoomID => A set (map[string]bool) of state event IDs which are between the two positions
+ stateNeeded := make(map[string]map[string]bool)
+
+ for rows.Next() {
+ var (
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ addIDs pq.StringArray
+ delIDs pq.StringArray
+ )
+ if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs); err != nil {
+ return nil, nil, err
+ }
+ // Sanity check for deleted state and whine if we see it. We don't need to do anything
+ // since it'll just mark the event as not being needed.
+ if len(addIDs) < len(delIDs) {
+ log.WithFields(log.Fields{
+ "since": oldPos,
+ "current": newPos,
+ "adds": addIDs,
+ "dels": delIDs,
+ }).Warn("StateBetween: ignoring deleted state")
+ }
+
+ // TODO: Handle redacted events
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false)
+ if err != nil {
+ return nil, nil, err
+ }
+ needSet := stateNeeded[ev.RoomID()]
+ if needSet == nil { // make set if required
+ needSet = make(map[string]bool)
+ }
+ for _, id := range delIDs {
+ needSet[id] = false
+ }
+ for _, id := range addIDs {
+ needSet[id] = true
+ }
+ stateNeeded[ev.RoomID()] = needSet
+
+ eventIDToEvent[ev.EventID()] = types.StreamEvent{
+ Event: ev,
+ StreamPosition: streamPos,
+ ExcludeFromSync: excludeFromSync,
+ }
+ }
+
+ return stateNeeded, eventIDToEvent, nil
+}
+
+// MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied,
+// then this function should only ever be used at startup, as it will race with inserting events if it is
+// done afterwards. If there are no inserted events, 0 is returned.
+func (s *outputRoomEventsStatements) selectMaxEventID(
+ ctx context.Context, txn *sql.Tx,
+) (id int64, err error) {
+ var nullableID sql.NullInt64
+ stmt := common.TxStmt(txn, s.selectMaxEventIDStmt)
+ err = stmt.QueryRowContext(ctx).Scan(&nullableID)
+ if nullableID.Valid {
+ id = nullableID.Int64
+ }
+ return
+}
+
+// InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position
+// of the inserted event.
+func (s *outputRoomEventsStatements) insertEvent(
+ ctx context.Context, txn *sql.Tx,
+ event *gomatrixserverlib.Event, addState, removeState []string,
+ transactionID *api.TransactionID, excludeFromSync bool,
+) (streamPos types.StreamPosition, err error) {
+ var txnID *string
+ var sessionID *int64
+ if transactionID != nil {
+ sessionID = &transactionID.SessionID
+ txnID = &transactionID.TransactionID
+ }
+
+ // Parse content as JSON and search for an "url" key
+ containsURL := false
+ var content map[string]interface{}
+ if json.Unmarshal(event.Content(), &content) != nil {
+ // Set containsURL to true if url is present
+ _, containsURL = content["url"]
+ }
+
+ streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return
+ }
+
+ insertStmt := common.TxStmt(txn, s.insertEventStmt)
+ selectStmt := common.TxStmt(txn, s.selectLastInsertedEventStmt)
+ _, err = insertStmt.ExecContext(
+ ctx,
+ streamPos,
+ event.RoomID(),
+ event.EventID(),
+ event.JSON(),
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ pq.StringArray(addState),
+ pq.StringArray(removeState),
+ sessionID,
+ txnID,
+ excludeFromSync,
+ )
+ if err != nil {
+ return
+ }
+ err = selectStmt.QueryRowContext(ctx).Scan(&streamPos)
+ return
+}
+
+// selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'.
+// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude
+// from sync.
+func (s *outputRoomEventsStatements) selectRecentEvents(
+ ctx context.Context, txn *sql.Tx,
+ roomID string, fromPos, toPos types.StreamPosition, limit int,
+ chronologicalOrder bool, onlySyncEvents bool,
+) ([]types.StreamEvent, error) {
+ var stmt *sql.Stmt
+ if onlySyncEvents {
+ stmt = common.TxStmt(txn, s.selectRecentEventsForSyncStmt)
+ } else {
+ stmt = common.TxStmt(txn, s.selectRecentEventsStmt)
+ }
+
+ rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ events, err := rowsToStreamEvents(rows)
+ if err != nil {
+ return nil, err
+ }
+ if chronologicalOrder {
+ // The events need to be returned from oldest to latest, which isn't
+ // necessary the way the SQL query returns them, so a sort is necessary to
+ // ensure the events are in the right order in the slice.
+ sort.SliceStable(events, func(i int, j int) bool {
+ return events[i].StreamPosition < events[j].StreamPosition
+ })
+ }
+ return events, nil
+}
+
+// selectEarlyEvents returns the earliest events in the given room, starting
+// from a given position, up to a maximum of 'limit'.
+func (s *outputRoomEventsStatements) selectEarlyEvents(
+ ctx context.Context, txn *sql.Tx,
+ roomID string, fromPos, toPos types.StreamPosition, limit int,
+) ([]types.StreamEvent, error) {
+ stmt := common.TxStmt(txn, s.selectEarlyEventsStmt)
+ rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ events, err := rowsToStreamEvents(rows)
+ if err != nil {
+ return nil, err
+ }
+ // The events need to be returned from oldest to latest, which isn't
+ // necessarily the way the SQL query returns them, so a sort is necessary to
+ // ensure the events are in the right order in the slice.
+ sort.SliceStable(events, func(i int, j int) bool {
+ return events[i].StreamPosition < events[j].StreamPosition
+ })
+ return events, nil
+}
+
+// selectEvents returns the events for the given event IDs. If an event is
+// missing from the database, it will be omitted.
+func (s *outputRoomEventsStatements) selectEvents(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) ([]types.StreamEvent, error) {
+ var returnEvents []types.StreamEvent
+ stmt := common.TxStmt(txn, s.selectEventsStmt)
+ for _, eventID := range eventIDs {
+ rows, err := stmt.QueryContext(ctx, eventID)
+ if err != nil {
+ return nil, err
+ }
+ if streamEvents, err := rowsToStreamEvents(rows); err == nil {
+ returnEvents = append(returnEvents, streamEvents...)
+ }
+ rows.Close() // nolint: errcheck
+ }
+ return returnEvents, nil
+}
+
+func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
+ var result []types.StreamEvent
+ for rows.Next() {
+ var (
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ sessionID *int64
+ txnID *string
+ transactionID *api.TransactionID
+ )
+ if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
+ return nil, err
+ }
+ // TODO: Handle redacted events
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON(eventBytes, false)
+ if err != nil {
+ return nil, err
+ }
+
+ if sessionID != nil && txnID != nil {
+ transactionID = &api.TransactionID{
+ SessionID: *sessionID,
+ TransactionID: *txnID,
+ }
+ }
+
+ result = append(result, types.StreamEvent{
+ Event: ev,
+ StreamPosition: streamPos,
+ TransactionID: transactionID,
+ ExcludeFromSync: excludeFromSync,
+ })
+ }
+ return result, nil
+}
diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go
new file mode 100644
index 00000000..f7075bd6
--- /dev/null
+++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go
@@ -0,0 +1,192 @@
+// Copyright 2018 New Vector Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const outputRoomEventsTopologySchema = `
+-- Stores output room events received from the roomserver.
+CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology (
+ event_id TEXT PRIMARY KEY,
+ topological_position BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+
+ UNIQUE(topological_position, room_id)
+);
+-- The topological order will be used in events selection and ordering
+-- CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id);
+`
+
+const insertEventInTopologySQL = "" +
+ "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" +
+ " VALUES ($1, $2, $3)" +
+ " ON CONFLICT (topological_position, room_id) DO UPDATE SET event_id = $1"
+
+const selectEventIDsInRangeASCSQL = "" +
+ "SELECT event_id FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" +
+ " ORDER BY topological_position ASC LIMIT $4"
+
+const selectEventIDsInRangeDESCSQL = "" +
+ "SELECT event_id FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" +
+ " ORDER BY topological_position DESC LIMIT $4"
+
+const selectPositionInTopologySQL = "" +
+ "SELECT topological_position FROM syncapi_output_room_events_topology" +
+ " WHERE event_id = $1"
+
+const selectMaxPositionInTopologySQL = "" +
+ "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1"
+
+const selectEventIDsFromPositionSQL = "" +
+ "SELECT event_id FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1 AND topological_position = $2"
+
+type outputRoomEventsTopologyStatements struct {
+ insertEventInTopologyStmt *sql.Stmt
+ selectEventIDsInRangeASCStmt *sql.Stmt
+ selectEventIDsInRangeDESCStmt *sql.Stmt
+ selectPositionInTopologyStmt *sql.Stmt
+ selectMaxPositionInTopologyStmt *sql.Stmt
+ selectEventIDsFromPositionStmt *sql.Stmt
+}
+
+func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(outputRoomEventsTopologySchema)
+ if err != nil {
+ return
+ }
+ if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil {
+ return
+ }
+ if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil {
+ return
+ }
+ if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil {
+ return
+ }
+ if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
+ return
+ }
+ if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
+ return
+ }
+ if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil {
+ return
+ }
+ return
+}
+
+// insertEventInTopology inserts the given event in the room's topology, based
+// on the event's depth.
+func (s *outputRoomEventsTopologyStatements) insertEventInTopology(
+ ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.Event,
+) (err error) {
+ stmt := common.TxStmt(txn, s.insertEventInTopologyStmt)
+ _, err = stmt.ExecContext(
+ ctx, event.EventID(), event.Depth(), event.RoomID(),
+ )
+ return
+}
+
+// selectEventIDsInRange selects the IDs of events which positions are within a
+// given range in a given room's topological order.
+// Returns an empty slice if no events match the given range.
+func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
+ ctx context.Context, txn *sql.Tx, roomID string,
+ fromPos, toPos types.StreamPosition,
+ limit int, chronologicalOrder bool,
+) (eventIDs []string, err error) {
+ // Decide on the selection's order according to whether chronological order
+ // is requested or not.
+ var stmt *sql.Stmt
+ if chronologicalOrder {
+ stmt = common.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
+ } else {
+ stmt = common.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
+ }
+
+ // Query the event IDs.
+ rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
+ if err == sql.ErrNoRows {
+ // If no event matched the request, return an empty slice.
+ return []string{}, nil
+ } else if err != nil {
+ return
+ }
+
+ // Return the IDs.
+ var eventID string
+ for rows.Next() {
+ if err = rows.Scan(&eventID); err != nil {
+ return
+ }
+ eventIDs = append(eventIDs, eventID)
+ }
+
+ return
+}
+
+// selectPositionInTopology returns the position of a given event in the
+// topology of the room it belongs to.
+func (s *outputRoomEventsTopologyStatements) selectPositionInTopology(
+ ctx context.Context, txn *sql.Tx, eventID string,
+) (pos types.StreamPosition, err error) {
+ stmt := common.TxStmt(txn, s.selectPositionInTopologyStmt)
+ err = stmt.QueryRowContext(ctx, eventID).Scan(&pos)
+ return
+}
+
+func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology(
+ ctx context.Context, txn *sql.Tx, roomID string,
+) (pos types.StreamPosition, err error) {
+ stmt := common.TxStmt(txn, s.selectMaxPositionInTopologyStmt)
+ err = stmt.QueryRowContext(ctx, roomID).Scan(&pos)
+ return
+}
+
+// selectEventIDsFromPosition returns the IDs of all events that have a given
+// position in the topology of a given room.
+func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition(
+ ctx context.Context, txn *sql.Tx, roomID string, pos types.StreamPosition,
+) (eventIDs []string, err error) {
+ // Query the event IDs.
+ stmt := common.TxStmt(txn, s.selectEventIDsFromPositionStmt)
+ rows, err := stmt.QueryContext(ctx, roomID, pos)
+ if err == sql.ErrNoRows {
+ // If no event matched the request, return an empty slice.
+ return []string{}, nil
+ } else if err != nil {
+ return
+ }
+ // Return the IDs.
+ var eventID string
+ for rows.Next() {
+ if err = rows.Scan(&eventID); err != nil {
+ return
+ }
+ eventIDs = append(eventIDs, eventID)
+ }
+ return
+}
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
new file mode 100644
index 00000000..260f7a95
--- /dev/null
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -0,0 +1,58 @@
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+)
+
+const streamIDTableSchema = `
+-- Global stream ID counter, used by other tables.
+CREATE TABLE IF NOT EXISTS syncapi_stream_id (
+ stream_name TEXT NOT NULL PRIMARY KEY,
+ stream_id INT DEFAULT 0,
+
+ UNIQUE(stream_name)
+);
+INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
+ ON CONFLICT DO NOTHING;
+`
+
+const increaseStreamIDStmt = "" +
+ "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1"
+
+const selectStreamIDStmt = "" +
+ "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
+
+type streamIDStatements struct {
+ increaseStreamIDStmt *sql.Stmt
+ selectStreamIDStmt *sql.Stmt
+}
+
+func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(streamIDTableSchema)
+ if err != nil {
+ return
+ }
+ if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil {
+ return
+ }
+ if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil {
+ return
+ }
+ return
+}
+
+func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+ increaseStmt := common.TxStmt(txn, s.increaseStreamIDStmt)
+ selectStmt := common.TxStmt(txn, s.selectStreamIDStmt)
+ if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
+ return
+ }
+ if err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
+ return
+ }
+ return
+}
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
new file mode 100644
index 00000000..8cfc1884
--- /dev/null
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -0,0 +1,1197 @@
+// Copyright 2017-2018 New Vector Ltd
+// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/url"
+ "time"
+
+ "github.com/sirupsen/logrus"
+
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/roomserver/api"
+
+ // Import the postgres database driver.
+ _ "github.com/lib/pq"
+ _ "github.com/mattn/go-sqlite3"
+
+ "github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/dendrite/typingserver/cache"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type stateDelta struct {
+ roomID string
+ stateEvents []gomatrixserverlib.Event
+ membership string
+ // The PDU stream position of the latest membership event for this user, if applicable.
+ // Can be 0 if there is no membership event in this delta.
+ membershipPos types.StreamPosition
+}
+
+// SyncServerDatasource represents a sync server datasource which manages
+// both the database for PDUs and caches for EDUs.
+type SyncServerDatasource struct {
+ db *sql.DB
+ common.PartitionOffsetStatements
+ streamID streamIDStatements
+ accountData accountDataStatements
+ events outputRoomEventsStatements
+ roomstate currentRoomStateStatements
+ invites inviteEventsStatements
+ typingCache *cache.TypingCache
+ topology outputRoomEventsTopologyStatements
+ backwardExtremities backwardExtremitiesStatements
+}
+
+// NewSyncServerDatasource creates a new sync server database
+// nolint: gocyclo
+func NewSyncServerDatasource(dataSourceName string) (*SyncServerDatasource, error) {
+ var d SyncServerDatasource
+ uri, err := url.Parse(dataSourceName)
+ if err != nil {
+ return nil, err
+ }
+ var cs string
+ if uri.Opaque != "" { // file:filename.db
+ cs = uri.Opaque
+ } else if uri.Path != "" { // file:///path/to/filename.db
+ cs = uri.Path
+ } else {
+ return nil, errors.New("no filename or path in connect string")
+ }
+ if d.db, err = sql.Open("sqlite3", cs); err != nil {
+ return nil, err
+ }
+ if err = d.prepare(); err != nil {
+ return nil, err
+ }
+ d.typingCache = cache.NewTypingCache()
+ return &d, nil
+}
+
+func (d *SyncServerDatasource) prepare() (err error) {
+ if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
+ return err
+ }
+ if err = d.streamID.prepare(d.db); err != nil {
+ return err
+ }
+ if err = d.accountData.prepare(d.db, &d.streamID); err != nil {
+ return err
+ }
+ if err = d.events.prepare(d.db, &d.streamID); err != nil {
+ return err
+ }
+ if err := d.roomstate.prepare(d.db, &d.streamID); err != nil {
+ return err
+ }
+ if err := d.invites.prepare(d.db, &d.streamID); err != nil {
+ return err
+ }
+ if err := d.topology.prepare(d.db); err != nil {
+ return err
+ }
+ if err := d.backwardExtremities.prepare(d.db); err != nil {
+ return err
+ }
+ return nil
+}
+
+// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
+func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
+ return d.roomstate.selectJoinedUsers(ctx)
+}
+
+// Events lookups a list of event by their event ID.
+// Returns a list of events matching the requested IDs found in the database.
+// If an event is not found in the database then it will be omitted from the list.
+// Returns an error if there was a problem talking with the database.
+// Does not include any transaction IDs in the returned events.
+func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
+ streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ // We don't include a device here as we only include transaction IDs in
+ // incremental syncs.
+ return d.StreamEventsToEvents(nil, streamEvents), nil
+}
+
+func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.Event) error {
+ // If the event is already known as a backward extremity, don't consider
+ // it as such anymore now that we have it.
+ isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID())
+ if err != nil {
+ return err
+ }
+ if isBackwardExtremity {
+ if err = d.backwardExtremities.deleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
+ return err
+ }
+ }
+
+ // Check if we have all of the event's previous events. If an event is
+ // missing, add it to the room's backward extremities.
+ prevEvents, err := d.events.selectEvents(ctx, txn, ev.PrevEventIDs())
+ if err != nil {
+ return err
+ }
+ var found bool
+ for _, eID := range ev.PrevEventIDs() {
+ found = false
+ for _, prevEv := range prevEvents {
+ if eID == prevEv.EventID() {
+ found = true
+ }
+ }
+
+ // If the event is missing, consider it a backward extremity.
+ if !found {
+ if err = d.backwardExtremities.insertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
+// when generating the sync stream position for this event. Returns the sync stream position for the inserted event.
+// Returns an error if there was a problem inserting this event.
+func (d *SyncServerDatasource) WriteEvent(
+ ctx context.Context,
+ ev *gomatrixserverlib.Event,
+ addStateEvents []gomatrixserverlib.Event,
+ addStateEventIDs, removeStateEventIDs []string,
+ transactionID *api.TransactionID, excludeFromSync bool,
+) (pduPosition types.StreamPosition, returnErr error) {
+ returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ pos, err := d.events.insertEvent(
+ ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
+ )
+ if err != nil {
+ fmt.Println("d.events.insertEvent:", err)
+ return err
+ }
+ pduPosition = pos
+
+ if err = d.topology.insertEventInTopology(ctx, txn, ev); err != nil {
+ fmt.Println("d.topology.insertEventInTopology:", err)
+ return err
+ }
+
+ if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
+ fmt.Println("d.handleBackwardExtremities:", err)
+ return err
+ }
+
+ if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
+ // Nothing to do, the event may have just been a message event.
+ fmt.Println("nothing to do")
+ return nil
+ }
+
+ return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition)
+ })
+
+ return pduPosition, returnErr
+}
+
+func (d *SyncServerDatasource) updateRoomState(
+ ctx context.Context, txn *sql.Tx,
+ removedEventIDs []string,
+ addedEvents []gomatrixserverlib.Event,
+ pduPosition types.StreamPosition,
+) error {
+ // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
+ for _, eventID := range removedEventIDs {
+ if err := d.roomstate.deleteRoomStateByEventID(ctx, txn, eventID); err != nil {
+ return err
+ }
+ }
+
+ for _, event := range addedEvents {
+ if event.StateKey() == nil {
+ // ignore non state events
+ continue
+ }
+ var membership *string
+ if event.Type() == "m.room.member" {
+ value, err := event.Membership()
+ if err != nil {
+ return err
+ }
+ membership = &value
+ }
+ if err := d.roomstate.upsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
+// If no event could be found, returns nil
+// If there was an issue during the retrieval, returns an error
+func (d *SyncServerDatasource) GetStateEvent(
+ ctx context.Context, roomID, evType, stateKey string,
+) (*gomatrixserverlib.Event, error) {
+ return d.roomstate.selectStateEvent(ctx, roomID, evType, stateKey)
+}
+
+// GetStateEventsForRoom fetches the state events for a given room.
+// Returns an empty slice if no state events could be found for this room.
+// Returns an error if there was an issue with the retrieval.
+func (d *SyncServerDatasource) GetStateEventsForRoom(
+ ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter,
+) (stateEvents []gomatrixserverlib.Event, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
+ return err
+ })
+ return
+}
+
+// GetEventsInRange retrieves all of the events on a given ordering using the
+// given extremities and limit.
+func (d *SyncServerDatasource) GetEventsInRange(
+ ctx context.Context,
+ from, to *types.PaginationToken,
+ roomID string, limit int,
+ backwardOrdering bool,
+) (events []types.StreamEvent, err error) {
+ // If the pagination token's type is types.PaginationTokenTypeTopology, the
+ // events must be retrieved from the rooms' topology table rather than the
+ // table contaning the syncapi server's whole stream of events.
+ if from.Type == types.PaginationTokenTypeTopology {
+ // Determine the backward and forward limit, i.e. the upper and lower
+ // limits to the selection in the room's topology, from the direction.
+ var backwardLimit, forwardLimit types.StreamPosition
+ if backwardOrdering {
+ // Backward ordering is antichronological (latest event to oldest
+ // one).
+ backwardLimit = to.PDUPosition
+ forwardLimit = from.PDUPosition
+ } else {
+ // Forward ordering is chronological (oldest event to latest one).
+ backwardLimit = from.PDUPosition
+ forwardLimit = to.PDUPosition
+ }
+
+ // Select the event IDs from the defined range.
+ var eIDs []string
+ eIDs, err = d.topology.selectEventIDsInRange(
+ ctx, nil, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering,
+ )
+ if err != nil {
+ return
+ }
+
+ // Retrieve the events' contents using their IDs.
+ events, err = d.events.selectEvents(ctx, nil, eIDs)
+ return
+ }
+
+ // If the pagination token's type is types.PaginationTokenTypeStream, the
+ // events must be retrieved from the table contaning the syncapi server's
+ // whole stream of events.
+
+ if backwardOrdering {
+ // When using backward ordering, we want the most recent events first.
+ if events, err = d.events.selectRecentEvents(
+ ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false,
+ ); err != nil {
+ return
+ }
+ } else {
+ // When using forward ordering, we want the least recent events first.
+ if events, err = d.events.selectEarlyEvents(
+ ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit,
+ ); err != nil {
+ return
+ }
+ }
+
+ return
+}
+
+// SyncPosition returns the latest positions for syncing.
+func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) {
+ return d.syncPositionTx(ctx, nil)
+}
+
+// BackwardExtremitiesForRoom returns the event IDs of all of the backward
+// extremities we know of for a given room.
+func (d *SyncServerDatasource) BackwardExtremitiesForRoom(
+ ctx context.Context, roomID string,
+) (backwardExtremities []string, err error) {
+ return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, nil, roomID)
+}
+
+// MaxTopologicalPosition returns the highest topological position for a given
+// room.
+func (d *SyncServerDatasource) MaxTopologicalPosition(
+ ctx context.Context, roomID string,
+) (types.StreamPosition, error) {
+ return d.topology.selectMaxPositionInTopology(ctx, nil, roomID)
+}
+
+// EventsAtTopologicalPosition returns all of the events matching a given
+// position in the topology of a given room.
+func (d *SyncServerDatasource) EventsAtTopologicalPosition(
+ ctx context.Context, roomID string, pos types.StreamPosition,
+) ([]types.StreamEvent, error) {
+ eIDs, err := d.topology.selectEventIDsFromPosition(ctx, nil, roomID, pos)
+ if err != nil {
+ return nil, err
+ }
+
+ return d.events.selectEvents(ctx, nil, eIDs)
+}
+
+func (d *SyncServerDatasource) EventPositionInTopology(
+ ctx context.Context, eventID string,
+) (types.StreamPosition, error) {
+ return d.topology.selectPositionInTopology(ctx, nil, eventID)
+}
+
+// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
+func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) {
+ return d.syncStreamPositionTx(ctx, nil)
+}
+
+func (d *SyncServerDatasource) syncStreamPositionTx(
+ ctx context.Context, txn *sql.Tx,
+) (types.StreamPosition, error) {
+ maxID, err := d.events.selectMaxEventID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ if maxAccountDataID > maxID {
+ maxID = maxAccountDataID
+ }
+ maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ if maxInviteID > maxID {
+ maxID = maxInviteID
+ }
+ return types.StreamPosition(maxID), nil
+}
+
+func (d *SyncServerDatasource) syncPositionTx(
+ ctx context.Context, txn *sql.Tx,
+) (sp types.PaginationToken, err error) {
+
+ maxEventID, err := d.events.selectMaxEventID(ctx, txn)
+ if err != nil {
+ return sp, err
+ }
+ maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
+ if err != nil {
+ return sp, err
+ }
+ if maxAccountDataID > maxEventID {
+ maxEventID = maxAccountDataID
+ }
+ maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
+ if err != nil {
+ return sp, err
+ }
+ if maxInviteID > maxEventID {
+ maxEventID = maxInviteID
+ }
+ sp.PDUPosition = types.StreamPosition(maxEventID)
+ sp.EDUTypingPosition = types.StreamPosition(d.typingCache.GetLatestSyncPosition())
+ return
+}
+
+// addPDUDeltaToResponse adds all PDU deltas to a sync response.
+// IDs of all rooms the user joined are returned so EDU deltas can be added for them.
+func (d *SyncServerDatasource) addPDUDeltaToResponse(
+ ctx context.Context,
+ device authtypes.Device,
+ fromPos, toPos types.StreamPosition,
+ numRecentEventsPerRoom int,
+ wantFullState bool,
+ res *types.Response,
+) (joinedRoomIDs []string, err error) {
+ txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
+ if err != nil {
+ return nil, err
+ }
+ var succeeded bool
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
+
+ stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request
+
+ // Work out which rooms to return in the response. This is done by getting not only the currently
+ // joined rooms, but also which rooms have membership transitions for this user between the 2 PDU stream positions.
+ // This works out what the 'state' key should be for each room as well as which membership block
+ // to put the room into.
+ var deltas []stateDelta
+ if !wantFullState {
+ deltas, joinedRoomIDs, err = d.getStateDeltas(
+ ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart,
+ )
+ } else {
+ deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync(
+ ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilterPart,
+ )
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ for _, delta := range deltas {
+ err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // TODO: This should be done in getStateDeltas
+ if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil {
+ return nil, err
+ }
+
+ succeeded = true
+ return joinedRoomIDs, nil
+}
+
+// addTypingDeltaToResponse adds all typing notifications to a sync response
+// since the specified position.
+func (d *SyncServerDatasource) addTypingDeltaToResponse(
+ since types.PaginationToken,
+ joinedRoomIDs []string,
+ res *types.Response,
+) error {
+ var jr types.JoinResponse
+ var ok bool
+ var err error
+ for _, roomID := range joinedRoomIDs {
+ if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter(
+ roomID, int64(since.EDUTypingPosition),
+ ); updated {
+ ev := gomatrixserverlib.ClientEvent{
+ Type: gomatrixserverlib.MTyping,
+ }
+ ev.Content, err = json.Marshal(map[string]interface{}{
+ "user_ids": typingUsers,
+ })
+ if err != nil {
+ return err
+ }
+
+ if jr, ok = res.Rooms.Join[roomID]; !ok {
+ jr = *types.NewJoinResponse()
+ }
+ jr.Ephemeral.Events = append(jr.Ephemeral.Events, ev)
+ res.Rooms.Join[roomID] = jr
+ }
+ }
+ return nil
+}
+
+// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
+// the positions of that type are not equal in fromPos and toPos.
+func (d *SyncServerDatasource) addEDUDeltaToResponse(
+ fromPos, toPos types.PaginationToken,
+ joinedRoomIDs []string,
+ res *types.Response,
+) (err error) {
+
+ if fromPos.EDUTypingPosition != toPos.EDUTypingPosition {
+ err = d.addTypingDeltaToResponse(
+ fromPos, joinedRoomIDs, res,
+ )
+ }
+
+ return
+}
+
+// IncrementalSync returns all the data needed in order to create an incremental
+// sync response for the given user. Events returned will include any client
+// transaction IDs associated with the given device. These transaction IDs come
+// from when the device sent the event via an API that included a transaction
+// ID.
+func (d *SyncServerDatasource) IncrementalSync(
+ ctx context.Context,
+ device authtypes.Device,
+ fromPos, toPos types.PaginationToken,
+ numRecentEventsPerRoom int,
+ wantFullState bool,
+) (*types.Response, error) {
+ nextBatchPos := fromPos.WithUpdates(toPos)
+ res := types.NewResponse(nextBatchPos)
+
+ var joinedRoomIDs []string
+ var err error
+ if fromPos.PDUPosition != toPos.PDUPosition || wantFullState {
+ joinedRoomIDs, err = d.addPDUDeltaToResponse(
+ ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res,
+ )
+ } else {
+ joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(
+ ctx, nil, device.UserID, gomatrixserverlib.Join,
+ )
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ err = d.addEDUDeltaToResponse(
+ fromPos, toPos, joinedRoomIDs, res,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed
+// to it. It returns toPos and joinedRoomIDs for use of adding EDUs.
+func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
+ ctx context.Context,
+ userID string,
+ numRecentEventsPerRoom int,
+) (
+ res *types.Response,
+ toPos types.PaginationToken,
+ joinedRoomIDs []string,
+ err error,
+) {
+ // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have
+ // a consistent view of the database throughout. This includes extracting the sync position.
+ // This does have the unfortunate side-effect that all the matrixy logic resides in this function,
+ // but it's better to not hide the fact that this is being done in a transaction.
+ txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot)
+ if err != nil {
+ return
+ }
+ var succeeded bool
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
+
+ // Get the current sync position which we will base the sync response on.
+ toPos, err = d.syncPositionTx(ctx, txn)
+ if err != nil {
+ return
+ }
+
+ res = types.NewResponse(toPos)
+
+ // Extract room state and recent events for all rooms the user is joined to.
+ joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
+ if err != nil {
+ return
+ }
+ fmt.Println("Joined rooms:", joinedRoomIDs)
+
+ stateFilterPart := gomatrixserverlib.DefaultStateFilter() // TODO: use filter provided in request
+
+ // Build up a /sync response. Add joined rooms.
+ for _, roomID := range joinedRoomIDs {
+ fmt.Println("WE'RE ON", roomID)
+
+ var stateEvents []gomatrixserverlib.Event
+ stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID, &stateFilterPart)
+ if err != nil {
+ fmt.Println("d.roomstate.selectCurrentState:", err)
+ return
+ }
+ //fmt.Println("State events:", stateEvents)
+ // TODO: When filters are added, we may need to call this multiple times to get enough events.
+ // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
+ var recentStreamEvents []types.StreamEvent
+ recentStreamEvents, err = d.events.selectRecentEvents(
+ ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition,
+ numRecentEventsPerRoom, true, true,
+ )
+ if err != nil {
+ fmt.Println("d.events.selectRecentEvents:", err)
+ return
+ }
+ //fmt.Println("Recent stream events:", recentStreamEvents)
+
+ // Retrieve the backward topology position, i.e. the position of the
+ // oldest event in the room's topology.
+ var backwardTopologyPos types.StreamPosition
+ backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
+ if err != nil {
+ fmt.Println("d.topology.selectPositionInTopology:", err)
+ return nil, types.PaginationToken{}, []string{}, err
+ }
+ fmt.Println("Backward topology position:", backwardTopologyPos)
+ if backwardTopologyPos-1 <= 0 {
+ backwardTopologyPos = types.StreamPosition(1)
+ } else {
+ backwardTopologyPos--
+ }
+
+ // We don't include a device here as we don't need to send down
+ // transaction IDs for complete syncs
+ recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents)
+ stateEvents = removeDuplicates(stateEvents, recentEvents)
+ jr := types.NewJoinResponse()
+ jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
+ jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
+ jr.Timeline.Limited = true
+ jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
+ res.Rooms.Join[roomID] = *jr
+ }
+
+ if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil {
+ fmt.Println("d.addInvitesToResponse:", err)
+ return
+ }
+
+ succeeded = true
+ return res, toPos, joinedRoomIDs, err
+}
+
+// CompleteSync returns a complete /sync API response for the given user.
+func (d *SyncServerDatasource) CompleteSync(
+ ctx context.Context, userID string, numRecentEventsPerRoom int,
+) (*types.Response, error) {
+ res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
+ ctx, userID, numRecentEventsPerRoom,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Use a zero value SyncPosition for fromPos so all EDU states are added.
+ err = d.addEDUDeltaToResponse(
+ types.PaginationToken{}, toPos, joinedRoomIDs, res,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ return res, nil
+}
+
+var txReadOnlySnapshot = sql.TxOptions{
+ // Set the isolation level so that we see a snapshot of the database.
+ // In PostgreSQL repeatable read transactions will see a snapshot taken
+ // at the first query, and since the transaction is read-only it can't
+ // run into any serialisation errors.
+ // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
+ Isolation: sql.LevelRepeatableRead,
+ ReadOnly: true,
+}
+
+// GetAccountDataInRange returns all account data for a given user inserted or
+// updated between two given positions
+// Returns a map following the format data[roomID] = []dataTypes
+// If no data is retrieved, returns an empty map
+// If there was an issue with the retrieval, returns an error
+func (d *SyncServerDatasource) GetAccountDataInRange(
+ ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
+ accountDataFilterPart *gomatrixserverlib.EventFilter,
+) (map[string][]string, error) {
+ return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
+}
+
+// UpsertAccountData keeps track of new or updated account data, by saving the type
+// of the new/updated data, and the user ID and room ID the data is related to (empty)
+// room ID means the data isn't specific to any room)
+// If no data with the given type, user ID and room ID exists in the database,
+// creates a new row, else update the existing one
+// Returns an error if there was an issue with the upsert
+func (d *SyncServerDatasource) UpsertAccountData(
+ ctx context.Context, userID, roomID, dataType string,
+) (sp types.StreamPosition, err error) {
+ txn, err := d.db.BeginTx(ctx, nil)
+ if err != nil {
+ return types.StreamPosition(0), err
+ }
+ var succeeded bool
+ defer func() {
+ txerr := common.EndTransaction(txn, &succeeded)
+ if err == nil && txerr != nil {
+ err = txerr
+ }
+ }()
+ sp, err = d.accountData.insertAccountData(ctx, txn, userID, roomID, dataType)
+ return
+}
+
+// AddInviteEvent stores a new invite event for a user.
+// If the invite was successfully stored this returns the stream ID it was stored at.
+// Returns an error if there was a problem communicating with the database.
+func (d *SyncServerDatasource) AddInviteEvent(
+ ctx context.Context, inviteEvent gomatrixserverlib.Event,
+) (types.StreamPosition, error) {
+ return d.invites.insertInviteEvent(ctx, inviteEvent)
+}
+
+// RetireInviteEvent removes an old invite event from the database.
+// Returns an error if there was a problem communicating with the database.
+func (d *SyncServerDatasource) RetireInviteEvent(
+ ctx context.Context, inviteEventID string,
+) error {
+ // TODO: Record that invite has been retired in a stream so that we can
+ // notify the user in an incremental sync.
+ err := d.invites.deleteInviteEvent(ctx, inviteEventID)
+ return err
+}
+
+func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
+ d.typingCache.SetTimeoutCallback(fn)
+}
+
+// AddTypingUser adds a typing user to the typing cache.
+// Returns the newly calculated sync position for typing notifications.
+func (d *SyncServerDatasource) AddTypingUser(
+ userID, roomID string, expireTime *time.Time,
+) types.StreamPosition {
+ return types.StreamPosition(d.typingCache.AddTypingUser(userID, roomID, expireTime))
+}
+
+// RemoveTypingUser removes a typing user from the typing cache.
+// Returns the newly calculated sync position for typing notifications.
+func (d *SyncServerDatasource) RemoveTypingUser(
+ userID, roomID string,
+) types.StreamPosition {
+ return types.StreamPosition(d.typingCache.RemoveUser(userID, roomID))
+}
+
+func (d *SyncServerDatasource) addInvitesToResponse(
+ ctx context.Context, txn *sql.Tx,
+ userID string,
+ fromPos, toPos types.StreamPosition,
+ res *types.Response,
+) error {
+ invites, err := d.invites.selectInviteEventsInRange(
+ ctx, txn, userID, fromPos, toPos,
+ )
+ if err != nil {
+ return err
+ }
+ for roomID, inviteEvent := range invites {
+ ir := types.NewInviteResponse()
+ ir.InviteState.Events = gomatrixserverlib.ToClientEvents(
+ []gomatrixserverlib.Event{inviteEvent}, gomatrixserverlib.FormatSync,
+ )
+ // TODO: add the invite state from the invite event.
+ res.Rooms.Invite[roomID] = *ir
+ }
+ return nil
+}
+
+// Retrieve the backward topology position, i.e. the position of the
+// oldest event in the room's topology.
+func (d *SyncServerDatasource) getBackwardTopologyPos(
+ ctx context.Context, txn *sql.Tx,
+ events []types.StreamEvent,
+) (pos types.StreamPosition) {
+ if len(events) > 0 {
+ pos, _ = d.topology.selectPositionInTopology(ctx, txn, events[0].EventID())
+ }
+ if pos-1 <= 0 {
+ pos = types.StreamPosition(1)
+ } else {
+ pos = pos - 1
+ }
+ return
+}
+
+// addRoomDeltaToResponse adds a room state delta to a sync response
+func (d *SyncServerDatasource) addRoomDeltaToResponse(
+ ctx context.Context,
+ device *authtypes.Device,
+ txn *sql.Tx,
+ fromPos, toPos types.StreamPosition,
+ delta stateDelta,
+ numRecentEventsPerRoom int,
+ res *types.Response,
+) error {
+ endPos := toPos
+ if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave {
+ // make sure we don't leak recent events after the leave event.
+ // TODO: History visibility makes this somewhat complex to handle correctly. For example:
+ // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join).
+ // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave
+ // in a single /sync request
+ // This is all "okay" assuming history_visibility == "shared" which it is by default.
+ endPos = delta.membershipPos
+ }
+ recentStreamEvents, err := d.events.selectRecentEvents(
+ ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos),
+ numRecentEventsPerRoom, true, true,
+ )
+ if err != nil {
+ return err
+ }
+ recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
+ delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
+ backwardTopologyPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents)
+
+ switch delta.membership {
+ case gomatrixserverlib.Join:
+ jr := types.NewJoinResponse()
+
+ jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
+ jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
+ jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
+ jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
+ res.Rooms.Join[delta.roomID] = *jr
+ case gomatrixserverlib.Leave:
+ fallthrough // transitions to leave are the same as ban
+ case gomatrixserverlib.Ban:
+ // TODO: recentEvents may contain events that this user is not allowed to see because they are
+ // no longer in the room.
+ lr := types.NewLeaveResponse()
+ lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
+ lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
+ lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
+ lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
+ res.Rooms.Leave[delta.roomID] = *lr
+ }
+
+ return nil
+}
+
+// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
+// Returns a map of room ID to list of events.
+func (d *SyncServerDatasource) fetchStateEvents(
+ ctx context.Context, txn *sql.Tx,
+ roomIDToEventIDSet map[string]map[string]bool,
+ eventIDToEvent map[string]types.StreamEvent,
+) (map[string][]types.StreamEvent, error) {
+ stateBetween := make(map[string][]types.StreamEvent)
+ missingEvents := make(map[string][]string)
+ for roomID, ids := range roomIDToEventIDSet {
+ events := stateBetween[roomID]
+ for id, need := range ids {
+ if !need {
+ continue // deleted state
+ }
+ e, ok := eventIDToEvent[id]
+ if ok {
+ events = append(events, e)
+ } else {
+ m := missingEvents[roomID]
+ m = append(m, id)
+ missingEvents[roomID] = m
+ }
+ }
+ stateBetween[roomID] = events
+ }
+
+ if len(missingEvents) > 0 {
+ // This happens when add_state_ids has an event ID which is not in the provided range.
+ // We need to explicitly fetch them.
+ allMissingEventIDs := []string{}
+ for _, missingEvIDs := range missingEvents {
+ allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
+ }
+ evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs)
+ if err != nil {
+ return nil, err
+ }
+ // we know we got them all otherwise an error would've been returned, so just loop the events
+ for _, ev := range evs {
+ roomID := ev.RoomID()
+ stateBetween[roomID] = append(stateBetween[roomID], ev)
+ }
+ }
+ return stateBetween, nil
+}
+
+func (d *SyncServerDatasource) fetchMissingStateEvents(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) ([]types.StreamEvent, error) {
+ // Fetch from the events table first so we pick up the stream ID for the
+ // event.
+ events, err := d.events.selectEvents(ctx, txn, eventIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ have := map[string]bool{}
+ for _, event := range events {
+ have[event.EventID()] = true
+ }
+ var missing []string
+ for _, eventID := range eventIDs {
+ if !have[eventID] {
+ missing = append(missing, eventID)
+ }
+ }
+ if len(missing) == 0 {
+ return events, nil
+ }
+
+ // If they are missing from the events table then they should be state
+ // events that we received from outside the main event stream.
+ // These should be in the room state table.
+ stateEvents, err := d.roomstate.selectEventsWithEventIDs(ctx, txn, missing)
+
+ if err != nil {
+ return nil, err
+ }
+ if len(stateEvents) != len(missing) {
+ return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
+ }
+ events = append(events, stateEvents...)
+ return events, nil
+}
+
+// getStateDeltas returns the state deltas between fromPos and toPos,
+// exclusive of oldPos, inclusive of newPos, for the rooms in which
+// the user has new membership events.
+// A list of joined room IDs is also returned in case the caller needs it.
+func (d *SyncServerDatasource) getStateDeltas(
+ ctx context.Context, device *authtypes.Device, txn *sql.Tx,
+ fromPos, toPos types.StreamPosition, userID string,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) ([]stateDelta, []string, error) {
+ // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
+ // - Get membership list changes for this user in this sync response
+ // - For each room which has membership list changes:
+ // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO).
+ // If it is, then we need to send the full room state down (and 'limited' is always true).
+ // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
+ // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
+ // - Get all CURRENTLY joined rooms, and add them to 'joined' block.
+ var deltas []stateDelta
+
+ // get all the state events ever between these two positions
+ stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart)
+ if err != nil {
+ return nil, nil, err
+ }
+ state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ for roomID, stateStreamEvents := range state {
+ for _, ev := range stateStreamEvents {
+ // TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event.
+ // We should be checking if the user was already joined at fromPos and not proceed if so. As a result of this,
+ // dupe join events will result in the entire room state coming down to the client again. This is added in
+ // the 'state' part of the response though, so is transparent modulo bandwidth concerns as it is not added to
+ // the timeline.
+ if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
+ if membership == gomatrixserverlib.Join {
+ // send full room state down instead of a delta
+ var s []types.StreamEvent
+ s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart)
+ if err != nil {
+ return nil, nil, err
+ }
+ state[roomID] = s
+ continue // we'll add this room in when we do joined rooms
+ }
+
+ deltas = append(deltas, stateDelta{
+ membership: membership,
+ membershipPos: ev.StreamPosition,
+ stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
+ roomID: roomID,
+ })
+ break
+ }
+ }
+ }
+
+ // Add in currently joined rooms
+ joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
+ if err != nil {
+ return nil, nil, err
+ }
+ for _, joinedRoomID := range joinedRoomIDs {
+ deltas = append(deltas, stateDelta{
+ membership: gomatrixserverlib.Join,
+ stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
+ roomID: joinedRoomID,
+ })
+ }
+
+ return deltas, joinedRoomIDs, nil
+}
+
+// getStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
+// requests with full_state=true.
+// Fetches full state for all joined rooms and uses selectStateInRange to get
+// updates for other rooms.
+func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
+ ctx context.Context, device *authtypes.Device, txn *sql.Tx,
+ fromPos, toPos types.StreamPosition, userID string,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) ([]stateDelta, []string, error) {
+ joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Use a reasonable initial capacity
+ deltas := make([]stateDelta, 0, len(joinedRoomIDs))
+
+ // Add full states for all joined rooms
+ for _, joinedRoomID := range joinedRoomIDs {
+ s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilterPart)
+ if stateErr != nil {
+ return nil, nil, stateErr
+ }
+ deltas = append(deltas, stateDelta{
+ membership: gomatrixserverlib.Join,
+ stateEvents: d.StreamEventsToEvents(device, s),
+ roomID: joinedRoomID,
+ })
+ }
+
+ // Get all the state events ever between these two positions
+ stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilterPart)
+ if err != nil {
+ return nil, nil, err
+ }
+ state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ for roomID, stateStreamEvents := range state {
+ for _, ev := range stateStreamEvents {
+ if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
+ if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
+ deltas = append(deltas, stateDelta{
+ membership: membership,
+ membershipPos: ev.StreamPosition,
+ stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
+ roomID: roomID,
+ })
+ }
+
+ break
+ }
+ }
+ }
+
+ return deltas, joinedRoomIDs, nil
+}
+
+func (d *SyncServerDatasource) currentStateStreamEventsForRoom(
+ ctx context.Context, txn *sql.Tx, roomID string,
+ stateFilterPart *gomatrixserverlib.StateFilter,
+) ([]types.StreamEvent, error) {
+ allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
+ if err != nil {
+ return nil, err
+ }
+ s := make([]types.StreamEvent, len(allState))
+ for i := 0; i < len(s); i++ {
+ s[i] = types.StreamEvent{Event: allState[i], StreamPosition: 0}
+ }
+ return s, nil
+}
+
+// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
+// matches the streamevent.transactionID device then the transaction ID gets
+// added to the unsigned section of the output event.
+func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event {
+ out := make([]gomatrixserverlib.Event, len(in))
+ for i := 0; i < len(in); i++ {
+ out[i] = in[i].Event
+ if device != nil && in[i].TransactionID != nil {
+ if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
+ err := out[i].SetUnsignedField(
+ "transaction_id", in[i].TransactionID.TransactionID,
+ )
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ "event_id": out[i].EventID(),
+ }).WithError(err).Warnf("Failed to add transaction ID to event")
+ }
+ }
+ }
+ }
+ return out
+}
+
+// There may be some overlap where events in stateEvents are already in recentEvents, so filter
+// them out so we don't include them twice in the /sync response. They should be in recentEvents
+// only, so clients get to the correct state once they have rolled forward.
+func removeDuplicates(stateEvents, recentEvents []gomatrixserverlib.Event) []gomatrixserverlib.Event {
+ for _, recentEv := range recentEvents {
+ if recentEv.StateKey() == nil {
+ continue // not a state event
+ }
+ // TODO: This is a linear scan over all the current state events in this room. This will
+ // be slow for big rooms. We should instead sort the state events by event ID (ORDER BY)
+ // then do a binary search to find matching events, similar to what roomserver does.
+ for j := 0; j < len(stateEvents); j++ {
+ if stateEvents[j].EventID() == recentEv.EventID() {
+ // overwrite the element to remove with the last element then pop the last element.
+ // This is orders of magnitude faster than re-slicing, but doesn't preserve ordering
+ // (we don't care about the order of stateEvents)
+ stateEvents[j] = stateEvents[len(stateEvents)-1]
+ stateEvents = stateEvents[:len(stateEvents)-1]
+ break // there shouldn't be multiple events with the same event ID
+ }
+ }
+ }
+ return stateEvents
+}
+
+// getMembershipFromEvent returns the value of content.membership iff the event is a state event
+// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
+func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) string {
+ if ev.Type() == "m.room.member" && ev.StateKeyEquals(userID) {
+ membership, err := ev.Membership()
+ if err != nil {
+ return ""
+ }
+ return membership
+ }
+ return ""
+}
diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go
index e6392844..c87024b2 100644
--- a/syncapi/storage/storage.go
+++ b/syncapi/storage/storage.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
+ "github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/typingserver/cache"
"github.com/matrix-org/gomatrixserverlib"
@@ -63,6 +64,8 @@ func NewSyncServerDatasource(dataSourceName string) (Database, error) {
switch uri.Scheme {
case "postgres":
return postgres.NewSyncServerDatasource(dataSourceName)
+ case "file":
+ return sqlite3.NewSyncServerDatasource(dataSourceName)
default:
return postgres.NewSyncServerDatasource(dataSourceName)
}
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index 06a8d6d8..22bd239f 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -32,12 +32,12 @@ import (
// RequestPool manages HTTP long-poll connections for /sync
type RequestPool struct {
db storage.Database
- accountDB *accounts.Database
+ accountDB accounts.Database
notifier *Notifier
}
// NewRequestPool makes a new RequestPool
-func NewRequestPool(db storage.Database, n *Notifier, adb *accounts.Database) *RequestPool {
+func NewRequestPool(db storage.Database, n *Notifier, adb accounts.Database) *RequestPool {
return &RequestPool{db, adb, n}
}
diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go
index ecf532ca..1535d2b1 100644
--- a/syncapi/syncapi.go
+++ b/syncapi/syncapi.go
@@ -36,8 +36,8 @@ import (
// component.
func SetupSyncAPIComponent(
base *basecomponent.BaseDendrite,
- deviceDB *devices.Database,
- accountsDB *accounts.Database,
+ deviceDB devices.Database,
+ accountsDB accounts.Database,
queryAPI api.RoomserverQueryAPI,
federation *gomatrixserverlib.FederationClient,
cfg *config.Dendrite,