From 49f760a30b6496c8b3e1ceaf98dccc4376f6605d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 23 Jan 2020 17:51:10 +0000 Subject: CS API: Support for /messages, fixes for /sync (#847) * Merge forward * Tidy up a bit * TODO: What to do with NextBatch here? * Replace SyncPosition with PaginationToken throughout syncapi * Fix PaginationTokens * Fix lint errors * Add a couple of missing functions into the syncapi external storage interface * Some updates based on review comments from @babolivier * Some updates based on review comments from @babolivier * argh whitespacing * Fix opentracing span * Remove dead code * Don't overshadow err (fix lint issue) * Handle extremities after inserting event into topology * Try insert event topology as ON CONFLICT DO NOTHING * Prevent OOB error in addRoomDeltaToResponse * Thwarted by gocyclo again * Fix NewPaginationTokenFromString, define unit test for it * Update pagination token test * Update sytest-whitelist * Hopefully fix some of the sync batch tokens * Remove extraneous sync position func * Revert to topology tokens in addRoomDeltaToResponse etc * Fix typo * Remove prevPDUPos as dead now that backwardTopologyPos is used instead * Fix selectEventsWithEventIDsSQL * Update sytest-blacklist * Update sytest-whitelist --- syncapi/storage/postgres/account_data_table.go | 5 +- .../storage/postgres/backward_extremities_table.go | 118 ++++++ .../storage/postgres/current_room_state_table.go | 9 +- syncapi/storage/postgres/invites_table.go | 5 +- .../storage/postgres/output_room_events_table.go | 197 ++++++---- .../postgres/output_room_events_topology_table.go | 188 ++++++++++ syncapi/storage/postgres/syncserver.go | 400 +++++++++++++++------ syncapi/storage/storage.go | 23 +- 8 files changed, 762 insertions(+), 183 deletions(-) create mode 100644 syncapi/storage/postgres/backward_extremities_table.go create mode 100644 syncapi/storage/postgres/output_room_events_topology_table.go (limited to 'syncapi/storage') diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index 36ba88cd..94e6ac41 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrix" ) @@ -89,7 +90,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func (s *accountDataStatements) insertAccountData( ctx context.Context, userID, roomID, dataType string, -) (pos int64, err error) { +) (pos types.StreamPosition, err error) { err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos) return } @@ -97,7 +98,7 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountDataInRange( ctx context.Context, userID string, - oldPos, newPos int64, + oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart, ) (data map[string][]string, err error) { data = make(map[string][]string) diff --git a/syncapi/storage/postgres/backward_extremities_table.go b/syncapi/storage/postgres/backward_extremities_table.go new file mode 100644 index 00000000..476d26fa --- /dev/null +++ b/syncapi/storage/postgres/backward_extremities_table.go @@ -0,0 +1,118 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" +) + +const backwardExtremitiesSchema = ` +-- Stores output room events received from the roomserver. +CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( + -- The 'room_id' key for the event. + room_id TEXT NOT NULL, + -- The event ID for the event. + event_id TEXT NOT NULL, + + PRIMARY KEY(room_id, event_id) +); +` + +const insertBackwardExtremitySQL = "" + + "INSERT INTO syncapi_backward_extremities (room_id, event_id)" + + " VALUES ($1, $2)" + +const selectBackwardExtremitiesForRoomSQL = "" + + "SELECT event_id FROM syncapi_backward_extremities WHERE room_id = $1" + +const isBackwardExtremitySQL = "" + + "SELECT EXISTS (" + + " SELECT TRUE FROM syncapi_backward_extremities" + + " WHERE room_id = $1 AND event_id = $2" + + ")" + +const deleteBackwardExtremitySQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND event_id = $2" + +type backwardExtremitiesStatements struct { + insertBackwardExtremityStmt *sql.Stmt + selectBackwardExtremitiesForRoomStmt *sql.Stmt + isBackwardExtremityStmt *sql.Stmt + deleteBackwardExtremityStmt *sql.Stmt +} + +func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(backwardExtremitiesSchema) + if err != nil { + return + } + if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { + return + } + if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { + return + } + if s.isBackwardExtremityStmt, err = db.Prepare(isBackwardExtremitySQL); err != nil { + return + } + if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { + return + } + return +} + +func (s *backwardExtremitiesStatements) insertsBackwardExtremity( + ctx context.Context, roomID, eventID string, +) (err error) { + _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID) + return +} + +func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (eventIDs []string, err error) { + eventIDs = make([]string, 0) + + rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + if err != nil { + return + } + + for rows.Next() { + var eID string + if err = rows.Scan(&eID); err != nil { + return + } + + eventIDs = append(eventIDs, eID) + } + + return +} + +func (s *backwardExtremitiesStatements) isBackwardExtremity( + ctx context.Context, roomID, eventID string, +) (isBE bool, err error) { + err = s.isBackwardExtremityStmt.QueryRowContext(ctx, roomID, eventID).Scan(&isBE) + return +} + +func (s *backwardExtremitiesStatements) deleteBackwardExtremity( + ctx context.Context, roomID, eventID string, +) (err error) { + _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID) + return +} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 8b208043..816cbb44 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" ) @@ -87,10 +88,10 @@ const selectStateEventSQL = "" + const selectEventsWithEventIDsSQL = "" + // TODO: The session_id and transaction_id blanks are here because otherwise - // the rowsToStreamEvents expects there to be exactly four columns. We need to + // the rowsToStreamEvents expects there to be exactly five columns. We need to // figure out if these really need to be in the DB, and if so, we need a // better permanent fix for this. - neilalexander, 2 Jan 2020 - "SELECT added_at, event_json, 0 AS session_id, '' AS transaction_id" + + "SELECT added_at, event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + " FROM syncapi_current_room_state WHERE event_id = ANY($1)" type currentRoomStateStatements struct { @@ -213,7 +214,7 @@ func (s *currentRoomStateStatements) deleteRoomStateByEventID( func (s *currentRoomStateStatements) upsertRoomState( ctx context.Context, txn *sql.Tx, - event gomatrixserverlib.Event, membership *string, addedAt int64, + event gomatrixserverlib.Event, membership *string, addedAt types.StreamPosition, ) error { // Parse content as JSON and search for an "url" key containsURL := false @@ -242,7 +243,7 @@ func (s *currentRoomStateStatements) upsertRoomState( func (s *currentRoomStateStatements) selectEventsWithEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index ced4bfc4..ca4bbeb5 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -86,7 +87,7 @@ func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) { func (s *inviteEventsStatements) insertInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (streamPos int64, err error) { +) (streamPos types.StreamPosition, err error) { err = s.insertInviteEventStmt.QueryRowContext( ctx, inviteEvent.RoomID(), @@ -107,7 +108,7 @@ func (s *inviteEventsStatements) deleteInviteEvent( // selectInviteEventsInRange returns a map of room ID to invite event for the // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) selectInviteEventsInRange( - ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos int64, + ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition, ) (map[string]gomatrixserverlib.Event, error) { stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos) diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index ca271593..be302d73 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -22,6 +22,7 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrix" "github.com/lib/pq" @@ -36,28 +37,35 @@ CREATE SEQUENCE IF NOT EXISTS syncapi_stream_id; -- Stores output room events received from the roomserver. CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( - -- An incrementing ID which denotes the position in the log that this event resides at. - -- NB: 'serial' makes no guarantees to increment by 1 every time, only that it increments. - -- This isn't a problem for us since we just want to order by this field. - id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), - -- The event ID for the event - event_id TEXT NOT NULL, - -- The 'room_id' key for the event. - room_id TEXT NOT NULL, - -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. - event_json TEXT NOT NULL, - -- The event type e.g 'm.room.member'. - type TEXT NOT NULL, - -- The 'sender' property of the event. - sender TEXT NOT NULL, - -- true if the event content contains a url key. - contains_url BOOL NOT NULL, - -- A list of event IDs which represent a delta of added/removed room state. This can be NULL - -- if there is no delta. - add_state_ids TEXT[], - remove_state_ids TEXT[], - session_id BIGINT, -- The client session that sent the event, if any - transaction_id TEXT -- The transaction id used to send the event, if any + -- An incrementing ID which denotes the position in the log that this event resides at. + -- NB: 'serial' makes no guarantees to increment by 1 every time, only that it increments. + -- This isn't a problem for us since we just want to order by this field. + id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), + -- The event ID for the event + event_id TEXT NOT NULL, + -- The 'room_id' key for the event. + room_id TEXT NOT NULL, + -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. + event_json TEXT NOT NULL, + -- The event type e.g 'm.room.member'. + type TEXT NOT NULL, + -- The 'sender' property of the event. + sender TEXT NOT NULL, + -- true if the event content contains a url key. + contains_url BOOL NOT NULL, + -- A list of event IDs which represent a delta of added/removed room state. This can be NULL + -- if there is no delta. + add_state_ids TEXT[], + remove_state_ids TEXT[], + -- The client session that sent the event, if any + session_id BIGINT, + -- The transaction id used to send the event, if any + transaction_id TEXT, + -- Should the event be excluded from responses to /sync requests. Useful for + -- events retrieved through backfilling that have a position in the stream + -- that relates to the moment these were retrieved rather than the moment these + -- were emitted. + exclude_from_sync BOOL DEFAULT FALSE ); -- for event selection CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_events(event_id); @@ -65,23 +73,33 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev const insertEventSQL = "" + "INSERT INTO syncapi_output_room_events (" + - "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" + - ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id" + "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id" const selectEventsSQL = "" + - "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" const selectRecentEventsSQL = "" + - "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" + + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + " ORDER BY id DESC LIMIT $4" +const selectRecentEventsForSyncSQL = "" + + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + + " ORDER BY id DESC LIMIT $4" + +const selectEarlyEventsSQL = "" + + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + " WHERE room_id = $1 AND id > $2 AND id <= $3" + + " ORDER BY id ASC LIMIT $4" + const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). const selectStateInRangeSQL = "" + - "SELECT id, event_json, add_state_ids, remove_state_ids" + + "SELECT id, event_json, exclude_from_sync, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " AND ( $3::text[] IS NULL OR sender = ANY($3) )" + @@ -93,11 +111,13 @@ const selectStateInRangeSQL = "" + " LIMIT $8" type outputRoomEventsStatements struct { - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - selectRecentEventsStmt *sql.Stmt - selectStateInRangeStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + selectRecentEventsStmt *sql.Stmt + selectRecentEventsForSyncStmt *sql.Stmt + selectEarlyEventsStmt *sql.Stmt + selectStateInRangeStmt *sql.Stmt } func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { @@ -117,6 +137,12 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil { return } + if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil { + return + } + if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil { + return + } if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { return } @@ -127,9 +153,9 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) selectStateInRange( - ctx context.Context, txn *sql.Tx, oldPos, newPos int64, + ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, stateFilterPart *gomatrix.FilterPart, -) (map[string]map[string]bool, map[string]streamEvent, error) { +) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( @@ -149,19 +175,20 @@ func (s *outputRoomEventsStatements) selectStateInRange( // - For each room ID, build up an array of event IDs which represents cumulative adds/removes // For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID // if they aren't in the event ID cache. We don't handle state deletion yet. - eventIDToEvent := make(map[string]streamEvent) + eventIDToEvent := make(map[string]types.StreamEvent) // RoomID => A set (map[string]bool) of state event IDs which are between the two positions stateNeeded := make(map[string]map[string]bool) for rows.Next() { var ( - streamPos int64 - eventBytes []byte - addIDs pq.StringArray - delIDs pq.StringArray + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + addIDs pq.StringArray + delIDs pq.StringArray ) - if err := rows.Scan(&streamPos, &eventBytes, &addIDs, &delIDs); err != nil { + if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs); err != nil { return nil, nil, err } // Sanity check for deleted state and whine if we see it. We don't need to do anything @@ -192,9 +219,10 @@ func (s *outputRoomEventsStatements) selectStateInRange( } stateNeeded[ev.RoomID()] = needSet - eventIDToEvent[ev.EventID()] = streamEvent{ - Event: ev, - streamPosition: streamPos, + eventIDToEvent[ev.EventID()] = types.StreamEvent{ + Event: ev, + StreamPosition: streamPos, + ExcludeFromSync: excludeFromSync, } } @@ -221,8 +249,8 @@ func (s *outputRoomEventsStatements) selectMaxEventID( func (s *outputRoomEventsStatements) insertEvent( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string, - transactionID *api.TransactionID, -) (streamPos int64, err error) { + transactionID *api.TransactionID, excludeFromSync bool, +) (streamPos types.StreamPosition, err error) { var txnID *string var sessionID *int64 if transactionID != nil { @@ -251,16 +279,53 @@ func (s *outputRoomEventsStatements) insertEvent( pq.StringArray(removeState), sessionID, txnID, + excludeFromSync, ).Scan(&streamPos) return } -// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. +// selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'. +// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude +// from sync. func (s *outputRoomEventsStatements) selectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, fromPos, toPos int64, limit int, -) ([]streamEvent, error) { - stmt := common.TxStmt(txn, s.selectRecentEventsStmt) + roomID string, fromPos, toPos types.StreamPosition, limit int, + chronologicalOrder bool, onlySyncEvents bool, +) ([]types.StreamEvent, error) { + var stmt *sql.Stmt + if onlySyncEvents { + stmt = common.TxStmt(txn, s.selectRecentEventsForSyncStmt) + } else { + stmt = common.TxStmt(txn, s.selectRecentEventsStmt) + } + + rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + events, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if chronologicalOrder { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(events, func(i int, j int) bool { + return events[i].StreamPosition < events[j].StreamPosition + }) + } + return events, nil +} + +// selectEarlyEvents returns the earliest events in the given room, starting +// from a given position, up to a maximum of 'limit'. +func (s *outputRoomEventsStatements) selectEarlyEvents( + ctx context.Context, txn *sql.Tx, + roomID string, fromPos, toPos types.StreamPosition, limit int, +) ([]types.StreamEvent, error) { + stmt := common.TxStmt(txn, s.selectEarlyEventsStmt) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) if err != nil { return nil, err @@ -274,16 +339,16 @@ func (s *outputRoomEventsStatements) selectRecentEvents( // necessarily the way the SQL query returns them, so a sort is necessary to // ensure the events are in the right order in the slice. sort.SliceStable(events, func(i int, j int) bool { - return events[i].streamPosition < events[j].streamPosition + return events[i].StreamPosition < events[j].StreamPosition }) return events, nil } -// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing -// from the database. +// selectEvents returns the events for the given event IDs. If an event is +// missing from the database, it will be omitted. func (s *outputRoomEventsStatements) selectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectEventsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { @@ -293,17 +358,18 @@ func (s *outputRoomEventsStatements) selectEvents( return rowsToStreamEvents(rows) } -func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { - var result []streamEvent +func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { + var result []types.StreamEvent for rows.Next() { var ( - streamPos int64 - eventBytes []byte - sessionID *int64 - txnID *string - transactionID *api.TransactionID + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + sessionID *int64 + txnID *string + transactionID *api.TransactionID ) - if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil { + if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { return nil, err } // TODO: Handle redacted events @@ -319,10 +385,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { } } - result = append(result, streamEvent{ - Event: ev, - streamPosition: streamPos, - transactionID: transactionID, + result = append(result, types.StreamEvent{ + Event: ev, + StreamPosition: streamPos, + TransactionID: transactionID, + ExcludeFromSync: excludeFromSync, }) } return result, nil diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go new file mode 100644 index 00000000..4a50b9a0 --- /dev/null +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -0,0 +1,188 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const outputRoomEventsTopologySchema = ` +-- Stores output room events received from the roomserver. +CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( + -- The event ID for the event. + event_id TEXT PRIMARY KEY, + -- The place of the event in the room's topology. This can usually be determined + -- from the event's depth. + topological_position BIGINT NOT NULL, + -- The 'room_id' key for the event. + room_id TEXT NOT NULL +); +-- The topological order will be used in events selection and ordering +CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id); +` + +const insertEventInTopologySQL = "" + + "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT DO NOTHING" + +const selectEventIDsInRangeASCSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + + " ORDER BY topological_position ASC LIMIT $4" + +const selectEventIDsInRangeDESCSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + + " ORDER BY topological_position DESC LIMIT $4" + +const selectPositionInTopologySQL = "" + + "SELECT topological_position FROM syncapi_output_room_events_topology" + + " WHERE event_id = $1" + +const selectMaxPositionInTopologySQL = "" + + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1" + +const selectEventIDsFromPositionSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND topological_position = $2" + +type outputRoomEventsTopologyStatements struct { + insertEventInTopologyStmt *sql.Stmt + selectEventIDsInRangeASCStmt *sql.Stmt + selectEventIDsInRangeDESCStmt *sql.Stmt + selectPositionInTopologyStmt *sql.Stmt + selectMaxPositionInTopologyStmt *sql.Stmt + selectEventIDsFromPositionStmt *sql.Stmt +} + +func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(outputRoomEventsTopologySchema) + if err != nil { + return + } + if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { + return + } + if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { + return + } + if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { + return + } + if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { + return + } + if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { + return + } + if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil { + return + } + return +} + +// insertEventInTopology inserts the given event in the room's topology, based +// on the event's depth. +func (s *outputRoomEventsTopologyStatements) insertEventInTopology( + ctx context.Context, event *gomatrixserverlib.Event, +) (err error) { + _, err = s.insertEventInTopologyStmt.ExecContext( + ctx, event.EventID(), event.Depth(), event.RoomID(), + ) + return +} + +// selectEventIDsInRange selects the IDs of events which positions are within a +// given range in a given room's topological order. +// Returns an empty slice if no events match the given range. +func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( + ctx context.Context, roomID string, fromPos, toPos types.StreamPosition, + limit int, chronologicalOrder bool, +) (eventIDs []string, err error) { + // Decide on the selection's order according to whether chronological order + // is requested or not. + var stmt *sql.Stmt + if chronologicalOrder { + stmt = s.selectEventIDsInRangeASCStmt + } else { + stmt = s.selectEventIDsInRangeDESCStmt + } + + // Query the event IDs. + rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + if err == sql.ErrNoRows { + // If no event matched the request, return an empty slice. + return []string{}, nil + } else if err != nil { + return + } + + // Return the IDs. + var eventID string + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + + return +} + +// selectPositionInTopology returns the position of a given event in the +// topology of the room it belongs to. +func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( + ctx context.Context, eventID string, +) (pos types.StreamPosition, err error) { + err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos) + return +} + +func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( + ctx context.Context, roomID string, +) (pos types.StreamPosition, err error) { + err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos) + return +} + +// selectEventIDsFromPosition returns the IDs of all events that have a given +// position in the topology of a given room. +func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition( + ctx context.Context, roomID string, pos types.StreamPosition, +) (eventIDs []string, err error) { + // Query the event IDs. + rows, err := s.selectEventIDsFromPositionStmt.QueryContext(ctx, roomID, pos) + if err == sql.ErrNoRows { + // If no event matched the request, return an empty slice. + return []string{}, nil + } else if err != nil { + return + } + // Return the IDs. + var eventID string + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + return +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 3a62d136..621aec95 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -20,7 +20,6 @@ import ( "database/sql" "encoding/json" "fmt" - "strconv" "time" "github.com/sirupsen/logrus" @@ -43,29 +42,24 @@ type stateDelta struct { membership string // The PDU stream position of the latest membership event for this user, if applicable. // Can be 0 if there is no membership event in this delta. - membershipPos int64 + membershipPos types.StreamPosition } -// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. -type streamEvent struct { - gomatrixserverlib.Event - streamPosition int64 - transactionID *api.TransactionID -} - -// SyncServerDatabase represents a sync server datasource which manages +// SyncServerDatasource represents a sync server datasource which manages // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { db *sql.DB common.PartitionOffsetStatements - accountData accountDataStatements - events outputRoomEventsStatements - roomstate currentRoomStateStatements - invites inviteEventsStatements - typingCache *cache.TypingCache + accountData accountDataStatements + events outputRoomEventsStatements + roomstate currentRoomStateStatements + invites inviteEventsStatements + typingCache *cache.TypingCache + topology outputRoomEventsTopologyStatements + backwardExtremities backwardExtremitiesStatements } -// NewSyncServerDatabase creates a new sync server database +// NewSyncServerDatasource creates a new sync server database func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error @@ -87,6 +81,12 @@ func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, er if err := d.invites.prepare(d.db); err != nil { return nil, err } + if err := d.topology.prepare(d.db); err != nil { + return nil, err + } + if err := d.backwardExtremities.prepare(d.db); err != nil { + return nil, err + } d.typingCache = cache.NewTypingCache() return &d, nil } @@ -109,7 +109,46 @@ func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([ // We don't include a device here as we only include transaction IDs in // incremental syncs. - return streamEventsToEvents(nil, streamEvents), nil + return d.StreamEventsToEvents(nil, streamEvents), nil +} + +func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, ev *gomatrixserverlib.Event) error { + // If the event is already known as a backward extremity, don't consider + // it as such anymore now that we have it. + isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, ev.RoomID(), ev.EventID()) + if err != nil { + return err + } + if isBackwardExtremity { + if err = d.backwardExtremities.deleteBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil { + return err + } + } + + // Check if we have all of the event's previous events. If an event is + // missing, add it to the room's backward extremities. + prevEvents, err := d.events.selectEvents(ctx, nil, ev.PrevEventIDs()) + if err != nil { + return err + } + var found bool + for _, eID := range ev.PrevEventIDs() { + found = false + for _, prevEv := range prevEvents { + if eID == prevEv.EventID() { + found = true + } + } + + // If the event is missing, consider it a backward extremity. + if !found { + if err = d.backwardExtremities.insertsBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil { + return err + } + } + } + + return nil } // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races @@ -120,16 +159,26 @@ func (d *SyncServerDatasource) WriteEvent( ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, - transactionID *api.TransactionID, -) (pduPosition int64, returnErr error) { + transactionID *api.TransactionID, excludeFromSync bool, +) (pduPosition types.StreamPosition, returnErr error) { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error - pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID) + pos, err := d.events.insertEvent( + ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, + ) if err != nil { return err } pduPosition = pos + if err = d.topology.insertEventInTopology(ctx, ev); err != nil { + return err + } + + if err = d.handleBackwardExtremities(ctx, ev); err != nil { + return err + } + if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { // Nothing to do, the event may have just been a message event. return nil @@ -137,14 +186,15 @@ func (d *SyncServerDatasource) WriteEvent( return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition) }) - return + + return pduPosition, returnErr } func (d *SyncServerDatasource) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, addedEvents []gomatrixserverlib.Event, - pduPosition int64, + pduPosition types.StreamPosition, ) error { // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removedEventIDs { @@ -196,14 +246,141 @@ func (d *SyncServerDatasource) GetStateEventsForRoom( return } +// GetEventsInRange retrieves all of the events on a given ordering using the +// given extremities and limit. +func (d *SyncServerDatasource) GetEventsInRange( + ctx context.Context, + from, to *types.PaginationToken, + roomID string, limit int, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { + // If the pagination token's type is types.PaginationTokenTypeTopology, the + // events must be retrieved from the rooms' topology table rather than the + // table contaning the syncapi server's whole stream of events. + if from.Type == types.PaginationTokenTypeTopology { + // Determine the backward and forward limit, i.e. the upper and lower + // limits to the selection in the room's topology, from the direction. + var backwardLimit, forwardLimit types.StreamPosition + if backwardOrdering { + // Backward ordering is antichronological (latest event to oldest + // one). + backwardLimit = to.PDUPosition + forwardLimit = from.PDUPosition + } else { + // Forward ordering is chronological (oldest event to latest one). + backwardLimit = from.PDUPosition + forwardLimit = to.PDUPosition + } + + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.topology.selectEventIDsInRange( + ctx, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering, + ) + if err != nil { + return + } + + // Retrieve the events' contents using their IDs. + events, err = d.events.selectEvents(ctx, nil, eIDs) + return + } + + // If the pagination token's type is types.PaginationTokenTypeStream, the + // events must be retrieved from the table contaning the syncapi server's + // whole stream of events. + + if backwardOrdering { + // When using backward ordering, we want the most recent events first. + if events, err = d.events.selectRecentEvents( + ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, + ); err != nil { + return + } + } else { + // When using forward ordering, we want the least recent events first. + if events, err = d.events.selectEarlyEvents( + ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, + ); err != nil { + return + } + } + + return +} + // SyncPosition returns the latest positions for syncing. -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) { +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) { return d.syncPositionTx(ctx, nil) } +// BackwardExtremitiesForRoom returns the event IDs of all of the backward +// extremities we know of for a given room. +func (d *SyncServerDatasource) BackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (backwardExtremities []string, err error) { + return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, roomID) +} + +// MaxTopologicalPosition returns the highest topological position for a given +// room. +func (d *SyncServerDatasource) MaxTopologicalPosition( + ctx context.Context, roomID string, +) (types.StreamPosition, error) { + return d.topology.selectMaxPositionInTopology(ctx, roomID) +} + +// EventsAtTopologicalPosition returns all of the events matching a given +// position in the topology of a given room. +func (d *SyncServerDatasource) EventsAtTopologicalPosition( + ctx context.Context, roomID string, pos types.StreamPosition, +) ([]types.StreamEvent, error) { + eIDs, err := d.topology.selectEventIDsFromPosition(ctx, roomID, pos) + if err != nil { + return nil, err + } + + return d.events.selectEvents(ctx, nil, eIDs) +} + +func (d *SyncServerDatasource) EventPositionInTopology( + ctx context.Context, eventID string, +) (types.StreamPosition, error) { + return d.topology.selectPositionInTopology(ctx, eventID) +} + +// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. +func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { + return d.syncStreamPositionTx(ctx, nil) +} + +func (d *SyncServerDatasource) syncStreamPositionTx( + ctx context.Context, txn *sql.Tx, +) (types.StreamPosition, error) { + maxID, err := d.events.selectMaxEventID(ctx, txn) + if err != nil { + return 0, err + } + maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) + if err != nil { + return 0, err + } + if maxAccountDataID > maxID { + maxID = maxAccountDataID + } + maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) + if err != nil { + return 0, err + } + if maxInviteID > maxID { + maxID = maxInviteID + } + return types.StreamPosition(maxID), nil +} + func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, -) (sp types.SyncPosition, err error) { +) (sp types.PaginationToken, err error) { maxEventID, err := d.events.selectMaxEventID(ctx, txn) if err != nil { @@ -223,10 +400,8 @@ func (d *SyncServerDatasource) syncPositionTx( if maxInviteID > maxEventID { maxEventID = maxInviteID } - sp.PDUPosition = maxEventID - - sp.TypingPosition = d.typingCache.GetLatestSyncPosition() - + sp.PDUPosition = types.StreamPosition(maxEventID) + sp.EDUTypingPosition = types.StreamPosition(d.typingCache.GetLatestSyncPosition()) return } @@ -235,7 +410,7 @@ func (d *SyncServerDatasource) syncPositionTx( func (d *SyncServerDatasource) addPDUDeltaToResponse( ctx context.Context, device authtypes.Device, - fromPos, toPos int64, + fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int, wantFullState bool, res *types.Response, @@ -287,7 +462,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse( // addTypingDeltaToResponse adds all typing notifications to a sync response // since the specified position. func (d *SyncServerDatasource) addTypingDeltaToResponse( - since int64, + since types.PaginationToken, joinedRoomIDs []string, res *types.Response, ) error { @@ -296,7 +471,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( var err error for _, roomID := range joinedRoomIDs { if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter( - roomID, since, + roomID, int64(since.EDUTypingPosition), ); updated { ev := gomatrixserverlib.ClientEvent{ Type: gomatrixserverlib.MTyping, @@ -321,14 +496,14 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // the positions of that type are not equal in fromPos and toPos. func (d *SyncServerDatasource) addEDUDeltaToResponse( - fromPos, toPos types.SyncPosition, + fromPos, toPos types.PaginationToken, joinedRoomIDs []string, res *types.Response, ) (err error) { - if fromPos.TypingPosition != toPos.TypingPosition { + if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { err = d.addTypingDeltaToResponse( - fromPos.TypingPosition, joinedRoomIDs, res, + fromPos, joinedRoomIDs, res, ) } @@ -343,7 +518,7 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) IncrementalSync( ctx context.Context, device authtypes.Device, - fromPos, toPos types.SyncPosition, + fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool, ) (*types.Response, error) { @@ -383,7 +558,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( numRecentEventsPerRoom int, ) ( res *types.Response, - toPos types.SyncPosition, + toPos types.PaginationToken, joinedRoomIDs []string, err error, ) { @@ -423,27 +598,37 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( } // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - var recentStreamEvents []streamEvent + var recentStreamEvents []types.StreamEvent recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom, + ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, + numRecentEventsPerRoom, true, true, + //ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom, ) if err != nil { return } + // Retrieve the backward topology position, i.e. the position of the + // oldest event in the room's topology. + var backwardTopologyPos types.StreamPosition + backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) + if err != nil { + return nil, types.PaginationToken{}, []string{}, err + } + if backwardTopologyPos-1 <= 0 { + backwardTopologyPos = types.StreamPosition(1) + } else { + backwardTopologyPos = backwardTopologyPos - 1 + } + // We don't include a device here as we don't need to send down // transaction IDs for complete syncs - recentEvents := streamEventsToEvents(nil, recentStreamEvents) - + recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() - if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { - // Use the short form of batch token for prev_batch - jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) - } else { - // Use the short form of batch token for prev_batch - jr.Timeline.PrevBatch = "1" - } + jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + ).String() jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync) @@ -471,7 +656,7 @@ func (d *SyncServerDatasource) CompleteSync( // Use a zero value SyncPosition for fromPos so all EDU states are added. err = d.addEDUDeltaToResponse( - types.SyncPosition{}, toPos, joinedRoomIDs, res, + types.PaginationToken{}, toPos, joinedRoomIDs, res, ) if err != nil { return nil, err @@ -496,7 +681,7 @@ var txReadOnlySnapshot = sql.TxOptions{ // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error func (d *SyncServerDatasource) GetAccountDataInRange( - ctx context.Context, userID string, oldPos, newPos int64, + ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart, ) (map[string][]string, error) { return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) @@ -510,7 +695,7 @@ func (d *SyncServerDatasource) GetAccountDataInRange( // Returns an error if there was an issue with the upsert func (d *SyncServerDatasource) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, -) (int64, error) { +) (types.StreamPosition, error) { return d.accountData.insertAccountData(ctx, userID, roomID, dataType) } @@ -519,7 +704,7 @@ func (d *SyncServerDatasource) UpsertAccountData( // Returns an error if there was a problem communicating with the database. func (d *SyncServerDatasource) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (int64, error) { +) (types.StreamPosition, error) { return d.invites.insertInviteEvent(ctx, inviteEvent) } @@ -542,26 +727,26 @@ func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallback // Returns the newly calculated sync position for typing notifications. func (d *SyncServerDatasource) AddTypingUser( userID, roomID string, expireTime *time.Time, -) int64 { - return d.typingCache.AddTypingUser(userID, roomID, expireTime) +) types.StreamPosition { + return types.StreamPosition(d.typingCache.AddTypingUser(userID, roomID, expireTime)) } // RemoveTypingUser removes a typing user from the typing cache. // Returns the newly calculated sync position for typing notifications. func (d *SyncServerDatasource) RemoveTypingUser( userID, roomID string, -) int64 { - return d.typingCache.RemoveUser(userID, roomID) +) types.StreamPosition { + return types.StreamPosition(d.typingCache.RemoveUser(userID, roomID)) } func (d *SyncServerDatasource) addInvitesToResponse( ctx context.Context, txn *sql.Tx, userID string, - fromPos, toPos int64, + fromPos, toPos types.StreamPosition, res *types.Response, ) error { invites, err := d.invites.selectInviteEventsInRange( - ctx, txn, userID, int64(fromPos), int64(toPos), + ctx, txn, userID, fromPos, toPos, ) if err != nil { return err @@ -577,12 +762,32 @@ func (d *SyncServerDatasource) addInvitesToResponse( return nil } +// Retrieve the backward topology position, i.e. the position of the +// oldest event in the room's topology. +func (d *SyncServerDatasource) getBackwardTopologyPos( + ctx context.Context, + events []types.StreamEvent, +) (pos types.StreamPosition, err error) { + if len(events) > 0 { + pos, err = d.topology.selectPositionInTopology(ctx, events[0].EventID()) + if err != nil { + return + } + } + if pos-1 <= 0 { + pos = types.StreamPosition(1) + } else { + pos = pos - 1 + } + return +} + // addRoomDeltaToResponse adds a room state delta to a sync response func (d *SyncServerDatasource) addRoomDeltaToResponse( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos int64, + fromPos, toPos types.StreamPosition, delta stateDelta, numRecentEventsPerRoom int, res *types.Response, @@ -598,38 +803,28 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( endPos = delta.membershipPos } recentStreamEvents, err := d.events.selectRecentEvents( - ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom, + ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos), + numRecentEventsPerRoom, true, true, ) if err != nil { return err } - recentEvents := streamEventsToEvents(device, recentStreamEvents) + recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back - var prevPDUPos int64 - - if len(recentEvents) == 0 { - if len(delta.stateEvents) == 0 { - // Don't bother appending empty room entries - return nil - } - - // If full_state=true and since is already up to date, then we'll have - // state events but no recent events. - prevPDUPos = toPos - 1 - } else { - prevPDUPos = recentStreamEvents[0].streamPosition - 1 - } - - if prevPDUPos <= 0 { - prevPDUPos = 1 + var backwardTopologyPos types.StreamPosition + backwardTopologyPos, err = d.getBackwardTopologyPos(ctx, recentStreamEvents) + if err != nil { + return err } switch delta.membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() - // Use the short form of batch token for prev_batch - jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) + + jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + ).String() jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -640,8 +835,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // TODO: recentEvents may contain events that this user is not allowed to see because they are // no longer in the room. lr := types.NewLeaveResponse() - // Use the short form of batch token for prev_batch - lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) + lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + ).String() lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -656,9 +852,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( func (d *SyncServerDatasource) fetchStateEvents( ctx context.Context, txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, - eventIDToEvent map[string]streamEvent, -) (map[string][]streamEvent, error) { - stateBetween := make(map[string][]streamEvent) + eventIDToEvent map[string]types.StreamEvent, +) (map[string][]types.StreamEvent, error) { + stateBetween := make(map[string][]types.StreamEvent) missingEvents := make(map[string][]string) for roomID, ids := range roomIDToEventIDSet { events := stateBetween[roomID] @@ -700,7 +896,7 @@ func (d *SyncServerDatasource) fetchStateEvents( func (d *SyncServerDatasource) fetchMissingStateEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. events, err := d.events.selectEvents(ctx, txn, eventIDs) @@ -743,7 +939,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents( // A list of joined room IDs is also returned in case the caller needs it. func (d *SyncServerDatasource) getStateDeltas( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos int64, userID string, + fromPos, toPos types.StreamPosition, userID string, stateFilterPart *gomatrix.FilterPart, ) ([]stateDelta, []string, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 @@ -776,7 +972,7 @@ func (d *SyncServerDatasource) getStateDeltas( if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { if membership == gomatrixserverlib.Join { // send full room state down instead of a delta - var s []streamEvent + var s []types.StreamEvent s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart) if err != nil { return nil, nil, err @@ -787,8 +983,8 @@ func (d *SyncServerDatasource) getStateDeltas( deltas = append(deltas, stateDelta{ membership: membership, - membershipPos: ev.streamPosition, - stateEvents: streamEventsToEvents(device, stateStreamEvents), + membershipPos: ev.StreamPosition, + stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), roomID: roomID, }) break @@ -804,7 +1000,7 @@ func (d *SyncServerDatasource) getStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, stateDelta{ membership: gomatrixserverlib.Join, - stateEvents: streamEventsToEvents(device, state[joinedRoomID]), + stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), roomID: joinedRoomID, }) } @@ -818,7 +1014,7 @@ func (d *SyncServerDatasource) getStateDeltas( // updates for other rooms. func (d *SyncServerDatasource) getStateDeltasForFullStateSync( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos int64, userID string, + fromPos, toPos types.StreamPosition, userID string, stateFilterPart *gomatrix.FilterPart, ) ([]stateDelta, []string, error) { joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) @@ -837,7 +1033,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( } deltas = append(deltas, stateDelta{ membership: gomatrixserverlib.Join, - stateEvents: streamEventsToEvents(device, s), + stateEvents: d.StreamEventsToEvents(device, s), roomID: joinedRoomID, }) } @@ -858,8 +1054,8 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. deltas = append(deltas, stateDelta{ membership: membership, - membershipPos: ev.streamPosition, - stateEvents: streamEventsToEvents(device, stateStreamEvents), + membershipPos: ev.StreamPosition, + stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), roomID: roomID, }) } @@ -875,29 +1071,29 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( func (d *SyncServerDatasource) currentStateStreamEventsForRoom( ctx context.Context, txn *sql.Tx, roomID string, stateFilterPart *gomatrix.FilterPart, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) if err != nil { return nil, err } - s := make([]streamEvent, len(allState)) + s := make([]types.StreamEvent, len(allState)) for i := 0; i < len(s); i++ { - s[i] = streamEvent{Event: allState[i], streamPosition: 0} + s[i] = types.StreamEvent{Event: allState[i], StreamPosition: 0} } return s, nil } -// streamEventsToEvents converts streamEvent to Event. If device is non-nil and +// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. -func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event { +func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event { out := make([]gomatrixserverlib.Event, len(in)) for i := 0; i < len(in); i++ { out[i] = in[i].Event - if device != nil && in[i].transactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID { + if device != nil && in[i].TransactionID != nil { + if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { err := out[i].SetUnsignedField( - "transaction_id", in[i].transactionID.TransactionID, + "transaction_id", in[i].TransactionID.TransactionID, ) if err != nil { logrus.WithFields(logrus.Fields{ diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index 5db4b3a1..4e8a2c83 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -33,19 +33,26 @@ type Database interface { common.PartitionStorer AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) - WriteEvent(ctx context.Context, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID) (pduPosition int64, returnErr error) + WriteEvent(context.Context, *gomatrixserverlib.Event, []gomatrixserverlib.Event, []string, []string, *api.TransactionID, bool) (types.StreamPosition, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error) GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart) (stateEvents []gomatrixserverlib.Event, err error) - SyncPosition(ctx context.Context) (types.SyncPosition, error) - IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.SyncPosition, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + SyncPosition(ctx context.Context) (types.PaginationToken, error) + IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) - GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos int64, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error) - UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (int64, error) - AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (int64, error) + GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error) + UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (types.StreamPosition, error) + AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (types.StreamPosition, error) RetireInviteEvent(ctx context.Context, inviteEventID string) error SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) - AddTypingUser(userID, roomID string, expireTime *time.Time) int64 - RemoveTypingUser(userID, roomID string) int64 + AddTypingUser(userID, roomID string, expireTime *time.Time) types.StreamPosition + RemoveTypingUser(userID, roomID string) types.StreamPosition + GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error) + EventsAtTopologicalPosition(ctx context.Context, roomID string, pos types.StreamPosition) ([]types.StreamEvent, error) + BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error) + MaxTopologicalPosition(ctx context.Context, roomID string) (types.StreamPosition, error) + StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event + SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) } // NewPublicRoomsServerDatabase opens a database connection. -- cgit v1.2.3