diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2020-01-23 17:51:10 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-23 17:51:10 +0000 |
commit | 49f760a30b6496c8b3e1ceaf98dccc4376f6605d (patch) | |
tree | b00d3fc17144cc83df1e5c7b8d1080ca19041243 /syncapi | |
parent | 43ecf8d1f909f4eb71bba93f6e7a57db59ec5941 (diff) |
CS API: Support for /messages, fixes for /sync (#847)
* Merge forward
* Tidy up a bit
* TODO: What to do with NextBatch here?
* Replace SyncPosition with PaginationToken throughout syncapi
* Fix PaginationTokens
* Fix lint errors
* Add a couple of missing functions into the syncapi external storage interface
* Some updates based on review comments from @babolivier
* Some updates based on review comments from @babolivier
* argh whitespacing
* Fix opentracing span
* Remove dead code
* Don't overshadow err (fix lint issue)
* Handle extremities after inserting event into topology
* Try insert event topology as ON CONFLICT DO NOTHING
* Prevent OOB error in addRoomDeltaToResponse
* Thwarted by gocyclo again
* Fix NewPaginationTokenFromString, define unit test for it
* Update pagination token test
* Update sytest-whitelist
* Hopefully fix some of the sync batch tokens
* Remove extraneous sync position func
* Revert to topology tokens in addRoomDeltaToResponse etc
* Fix typo
* Remove prevPDUPos as dead now that backwardTopologyPos is used instead
* Fix selectEventsWithEventIDsSQL
* Update sytest-blacklist
* Update sytest-whitelist
Diffstat (limited to 'syncapi')
-rw-r--r-- | syncapi/consumers/clientapi.go | 2 | ||||
-rw-r--r-- | syncapi/consumers/roomserver.go | 5 | ||||
-rw-r--r-- | syncapi/consumers/typingserver.go | 11 | ||||
-rw-r--r-- | syncapi/routing/messages.go | 482 | ||||
-rw-r--r-- | syncapi/routing/routing.go | 18 | ||||
-rw-r--r-- | syncapi/storage/postgres/account_data_table.go | 5 | ||||
-rw-r--r-- | syncapi/storage/postgres/backward_extremities_table.go | 118 | ||||
-rw-r--r-- | syncapi/storage/postgres/current_room_state_table.go | 9 | ||||
-rw-r--r-- | syncapi/storage/postgres/invites_table.go | 5 | ||||
-rw-r--r-- | syncapi/storage/postgres/output_room_events_table.go | 197 | ||||
-rw-r--r-- | syncapi/storage/postgres/output_room_events_topology_table.go | 188 | ||||
-rw-r--r-- | syncapi/storage/postgres/syncserver.go | 400 | ||||
-rw-r--r-- | syncapi/storage/storage.go | 23 | ||||
-rw-r--r-- | syncapi/sync/notifier.go | 10 | ||||
-rw-r--r-- | syncapi/sync/notifier_test.go | 40 | ||||
-rw-r--r-- | syncapi/sync/request.go | 39 | ||||
-rw-r--r-- | syncapi/sync/requestpool.go | 10 | ||||
-rw-r--r-- | syncapi/sync/userstream.go | 10 | ||||
-rw-r--r-- | syncapi/syncapi.go | 6 | ||||
-rw-r--r-- | syncapi/types/types.go | 153 | ||||
-rw-r--r-- | syncapi/types/types_test.go | 52 |
21 files changed, 1502 insertions, 281 deletions
diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index ed39cd2d..17f2c522 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -90,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error }).Panicf("could not save account data") } - s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.SyncPosition{PDUPosition: pduPos}) + s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.PaginationToken{PDUPosition: pduPos}) return nil } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index cde2f508..ba1b7dc5 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -133,6 +133,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( msg.AddsStateEventIDs, msg.RemovesStateEventIDs, msg.TransactionID, + false, ) if err != nil { // panic rather than continue with an inconsistent database @@ -144,7 +145,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( }).Panicf("roomserver output log: write event failure") return nil } - s.notifier.OnNewEvent(&ev, "", nil, types.SyncPosition{PDUPosition: pduPos}) + s.notifier.OnNewEvent(&ev, "", nil, types.PaginationToken{PDUPosition: pduPos}) return nil } @@ -161,7 +162,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( }).Panicf("roomserver output log: write invite failure") return nil } - s.notifier.OnNewEvent(&msg.Event, "", nil, types.SyncPosition{PDUPosition: pduPos}) + s.notifier.OnNewEvent(&msg.Event, "", nil, types.PaginationToken{PDUPosition: pduPos}) return nil } diff --git a/syncapi/consumers/typingserver.go b/syncapi/consumers/typingserver.go index 392f7987..36925441 100644 --- a/syncapi/consumers/typingserver.go +++ b/syncapi/consumers/typingserver.go @@ -63,7 +63,12 @@ func NewOutputTypingEventConsumer( // Start consuming from typing api func (s *OutputTypingEventConsumer) Start() error { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { - s.notifier.OnNewEvent(nil, roomID, nil, types.SyncPosition{TypingPosition: latestSyncPosition}) + s.notifier.OnNewEvent( + nil, roomID, nil, + types.PaginationToken{ + EDUTypingPosition: types.StreamPosition(latestSyncPosition), + }, + ) }) return s.typingConsumer.Start() @@ -83,7 +88,7 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error "typing": output.Event.Typing, }).Debug("received data from typing server") - var typingPos int64 + var typingPos types.StreamPosition typingEvent := output.Event if typingEvent.Typing { typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime) @@ -91,6 +96,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) } - s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.SyncPosition{TypingPosition: typingPos}) + s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.PaginationToken{EDUTypingPosition: typingPos}) return nil } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go new file mode 100644 index 00000000..26f48ca4 --- /dev/null +++ b/syncapi/routing/messages.go @@ -0,0 +1,482 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "net/http" + "sort" + "strconv" + + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + log "github.com/sirupsen/logrus" +) + +type messagesReq struct { + ctx context.Context + db storage.Database + queryAPI api.RoomserverQueryAPI + federation *gomatrixserverlib.FederationClient + cfg *config.Dendrite + roomID string + from *types.PaginationToken + to *types.PaginationToken + wasToProvided bool + limit int + backwardOrdering bool +} + +type messagesResp struct { + Start string `json:"start"` + End string `json:"end"` + Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` +} + +const defaultMessagesLimit = 10 + +// OnIncomingMessagesRequest implements the /messages endpoint from the +// client-server API. +// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages +func OnIncomingMessagesRequest( + req *http.Request, db storage.Database, roomID string, + federation *gomatrixserverlib.FederationClient, + queryAPI api.RoomserverQueryAPI, + cfg *config.Dendrite, +) util.JSONResponse { + var err error + + // Extract parameters from the request's URL. + // Pagination tokens. + from, err := types.NewPaginationTokenFromString(req.URL.Query().Get("from")) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()), + } + } + + // Direction to return events from. + dir := req.URL.Query().Get("dir") + if dir != "b" && dir != "f" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"), + } + } + // A boolean is easier to handle in this case, especially since dir is sure + // to have one of the two accepted values (so dir == "f" <=> !backwardOrdering). + backwardOrdering := (dir == "b") + + // Pagination tokens. To is optional, and its default value depends on the + // direction ("b" or "f"). + var to *types.PaginationToken + wasToProvided := true + if s := req.URL.Query().Get("to"); len(s) > 0 { + to, err = types.NewPaginationTokenFromString(s) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()), + } + } + } else { + // If "to" isn't provided, it defaults to either the earliest stream + // position (if we're going backward) or to the latest one (if we're + // going forward). + to, err = setToDefault(req.Context(), db, backwardOrdering, roomID) + if err != nil { + return httputil.LogThenError(req, err) + } + wasToProvided = false + } + + // Maximum number of events to return; defaults to 10. + limit := defaultMessagesLimit + if len(req.URL.Query().Get("limit")) > 0 { + limit, err = strconv.Atoi(req.URL.Query().Get("limit")) + + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()), + } + } + } + // TODO: Implement filtering (#587) + + // Check the room ID's format. + if _, _, err = gomatrixserverlib.SplitID('!', roomID); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Bad room ID: " + err.Error()), + } + } + + mReq := messagesReq{ + ctx: req.Context(), + db: db, + queryAPI: queryAPI, + federation: federation, + cfg: cfg, + roomID: roomID, + from: from, + to: to, + wasToProvided: wasToProvided, + limit: limit, + backwardOrdering: backwardOrdering, + } + + clientEvents, start, end, err := mReq.retrieveEvents() + if err != nil { + return httputil.LogThenError(req, err) + } + + // Respond with the events. + return util.JSONResponse{ + Code: http.StatusOK, + JSON: messagesResp{ + Chunk: clientEvents, + Start: start.String(), + End: end.String(), + }, + } +} + +// retrieveEvents retrieve events from the local database for a request on +// /messages. If there's not enough events to retrieve, it asks another +// homeserver in the room for older events. +// Returns an error if there was an issue talking to the database or with the +// remote homeserver. +func (r *messagesReq) retrieveEvents() ( + clientEvents []gomatrixserverlib.ClientEvent, start, + end *types.PaginationToken, err error, +) { + // Retrieve the events from the local database. + streamEvents, err := r.db.GetEventsInRange( + r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, + ) + if err != nil { + return + } + + var events []gomatrixserverlib.Event + + // There can be two reasons for streamEvents to be empty: either we've + // reached the oldest event in the room (or the most recent one, depending + // on the ordering), or we've reached a backward extremity. + if len(streamEvents) == 0 { + if events, err = r.handleEmptyEventsSlice(); err != nil { + return + } + } else { + if events, err = r.handleNonEmptyEventsSlice(streamEvents); err != nil { + return + } + } + + // If we didn't get any event, we don't need to proceed any further. + if len(events) == 0 { + return []gomatrixserverlib.ClientEvent{}, r.from, r.to, nil + } + + // Sort the events to ensure we send them in the right order. We currently + // do that based on the event's timestamp. + if r.backwardOrdering { + sort.SliceStable(events, func(i int, j int) bool { + // Backward ordering is antichronological (latest event to oldest + // one). + return sortEvents(&(events[j]), &(events[i])) + }) + } else { + sort.SliceStable(events, func(i int, j int) bool { + // Forward ordering is chronological (oldest event to latest one). + return sortEvents(&(events[i]), &(events[j])) + }) + } + + // Convert all of the events into client events. + clientEvents = gomatrixserverlib.ToClientEvents(events, gomatrixserverlib.FormatAll) + // Get the position of the first and the last event in the room's topology. + // This position is currently determined by the event's depth, so we could + // also use it instead of retrieving from the database. However, if we ever + // change the way topological positions are defined (as depth isn't the most + // reliable way to define it), it would be easier and less troublesome to + // only have to change it in one place, i.e. the database. + startPos, err := r.db.EventPositionInTopology( + r.ctx, streamEvents[0].EventID(), + ) + if err != nil { + return + } + endPos, err := r.db.EventPositionInTopology( + r.ctx, streamEvents[len(streamEvents)-1].EventID(), + ) + if err != nil { + return + } + // Generate pagination tokens to send to the client using the positions + // retrieved previously. + start = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, startPos, 0, + ) + end = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, endPos, 0, + ) + + if r.backwardOrdering { + // A stream/topological position is a cursor located between two events. + // While they are identified in the code by the event on their right (if + // we consider a left to right chronological order), tokens need to refer + // to them by the event on their left, therefore we need to decrement the + // end position we send in the response if we're going backward. + end.PDUPosition-- + } + + // The lowest token value is 1, therefore we need to manually set it to that + // value if we're below it. + if end.PDUPosition < types.StreamPosition(1) { + end.PDUPosition = types.StreamPosition(1) + } + + return clientEvents, start, end, err +} + +// handleEmptyEventsSlice handles the case where the initial request to the +// database returned an empty slice of events. It does so by checking whether +// the set is empty because we've reached a backward extremity, and if that is +// the case, by retrieving as much events as requested by backfilling from +// another homeserver. +// Returns an error if there was an issue talking with the database or +// backfilling. +func (r *messagesReq) handleEmptyEventsSlice() ( + events []gomatrixserverlib.Event, err error, +) { + backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) + + // Check if we have backward extremities for this room. + if len(backwardExtremities) > 0 { + // If so, retrieve as much events as needed through backfilling. + events, err = r.backfill(backwardExtremities, r.limit) + if err != nil { + return + } + } else { + // If not, it means the slice was empty because we reached the room's + // creation, so return an empty slice. + events = []gomatrixserverlib.Event{} + } + + return +} + +// handleNonEmptyEventsSlice handles the case where the initial request to the +// database returned a non-empty slice of events. It does so by checking whether +// events are missing from the expected result, and retrieve missing events +// through backfilling if needed. +// Returns an error if there was an issue while backfilling. +func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent) ( + events []gomatrixserverlib.Event, err error, +) { + // Check if we have enough events. + isSetLargeEnough := true + if len(streamEvents) < r.limit { + if r.backwardOrdering { + if r.wasToProvided { + // The condition in the SQL query is a strict "greater than" so + // we need to check against to-1. + streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition) + isSetLargeEnough = (r.to.PDUPosition-1 == streamPos) + } + } else { + streamPos := types.StreamPosition(streamEvents[0].StreamPosition) + isSetLargeEnough = (r.from.PDUPosition-1 == streamPos) + } + } + + // Check if the slice contains a backward extremity. + backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) + if err != nil { + return + } + + // Backfill is needed if we've reached a backward extremity and need more + // events. It's only needed if the direction is backward. + if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering { + var pdus []gomatrixserverlib.Event + // Only ask the remote server for enough events to reach the limit. + pdus, err = r.backfill(backwardExtremities, r.limit-len(streamEvents)) + if err != nil { + return + } + + // Append the PDUs to the list to send back to the client. + events = append(events, pdus...) + } + + // Append the events ve previously retrieved locally. + events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...) + + return +} + +// containsBackwardExtremity checks if a slice of StreamEvent contains a +// backward extremity. It does so by selecting the earliest event in the slice +// and by checking the presence in the database of all of its parent events, and +// considers the event itself a backward extremity if at least one of the parent +// events doesn't exist in the database. +// Returns an error if there was an issue with talking to the database. +func (r *messagesReq) containsBackwardExtremity(events []types.StreamEvent) (bool, error) { + // Select the earliest retrieved event. + var ev *types.StreamEvent + if r.backwardOrdering { + ev = &(events[len(events)-1]) + } else { + ev = &(events[0]) + } + // Get the earliest retrieved event's parents. + prevIDs := ev.PrevEventIDs() + prevs, err := r.db.Events(r.ctx, prevIDs) + if err != nil { + return false, nil + } + // Check if we have all of the events we requested. If not, it means we've + // reached a backward extremity. + var eventInDB bool + var id string + // Iterate over the IDs we used in the request. + for _, id = range prevIDs { + eventInDB = false + // Iterate over the events we got in response. + for _, ev := range prevs { + if ev.EventID() == id { + eventInDB = true + } + } + // One occurrence of one the event's parents not being present in the + // database is enough to say that the event is a backward extremity. + if !eventInDB { + return true, nil + } + } + + return false, nil +} + +// backfill performs a backfill request over the federation on another +// homeserver in the room. +// See: https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid +// It also stores the PDUs retrieved from the remote homeserver's response to +// the database. +// Returns with an empty string if the remote homeserver didn't return with any +// event, or if there is no remote homeserver to contact. +// Returns an error if there was an issue with retrieving the list of servers in +// the room or sending the request. +func (r *messagesReq) backfill(fromEventIDs []string, limit int) ([]gomatrixserverlib.Event, error) { + // Query the list of servers in the room when one of the backward extremities + // was sent. + var serversResponse api.QueryServersInRoomAtEventResponse + serversRequest := api.QueryServersInRoomAtEventRequest{ + RoomID: r.roomID, + EventID: fromEventIDs[0], + } + if err := r.queryAPI.QueryServersInRoomAtEvent(r.ctx, &serversRequest, &serversResponse); err != nil { + return nil, err + } + + // Use the first server from the response, except if that server is us. + // In that case, use the second one if the roomserver responded with + // enough servers. If not, use an empty string to prevent the backfill + // from happening as there's no server to direct the request towards. + // TODO: Be smarter at selecting the server to direct the request + // towards. + srvToBackfillFrom := serversResponse.Servers[0] + if srvToBackfillFrom == r.cfg.Matrix.ServerName { + if len(serversResponse.Servers) > 1 { + srvToBackfillFrom = serversResponse.Servers[1] + } else { + srvToBackfillFrom = gomatrixserverlib.ServerName("") + log.Warn("Not enough servers to backfill from") + } + } + + pdus := make([]gomatrixserverlib.Event, 0) + + // If the roomserver responded with at least one server that isn't us, + // send it a request for backfill. + if len(srvToBackfillFrom) > 0 { + txn, err := r.federation.Backfill( + r.ctx, srvToBackfillFrom, r.roomID, limit, fromEventIDs, + ) + if err != nil { + return nil, err + } + + pdus = txn.PDUs + + // Store the events in the database, while marking them as unfit to show + // up in responses to sync requests. + for _, pdu := range pdus { + if _, err = r.db.WriteEvent( + r.ctx, &pdu, []gomatrixserverlib.Event{}, []string{}, []string{}, + nil, true, + ); err != nil { + return nil, err + } + } + } + + return pdus, nil +} + +// setToDefault returns the default value for the "to" query parameter of a +// request to /messages if not provided. It defaults to either the earliest +// topological position (if we're going backward) or to the latest one (if we're +// going forward). +// Returns an error if there was an issue with retrieving the latest position +// from the database +func setToDefault( + ctx context.Context, db storage.Database, backwardOrdering bool, + roomID string, +) (to *types.PaginationToken, err error) { + if backwardOrdering { + to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 1, 0) + } else { + var pos types.StreamPosition + pos, err = db.MaxTopologicalPosition(ctx, roomID) + if err != nil { + return + } + + to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, 0) + } + + return +} + +// sortEvents is a function to give to sort.SliceStable, and compares the +// timestamp of two Matrix events. +// Returns true if the first event happened before the second one, false +// otherwise. +func sortEvents(e1 *gomatrixserverlib.Event, e2 *gomatrixserverlib.Event) bool { + t := e1.OriginServerTS().Time() + return e2.OriginServerTS().Time().After(t) +} diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index bd9389bd..8916565d 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -22,8 +22,11 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/common/config" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/sync" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -34,7 +37,12 @@ const pathPrefixR0 = "/_matrix/client/r0" // Due to Setup being used to call many other functions, a gocyclo nolint is // applied: // nolint: gocyclo -func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, deviceDB *devices.Database) { +func Setup( + apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, + deviceDB *devices.Database, federation *gomatrixserverlib.FederationClient, + queryAPI api.RoomserverQueryAPI, + cfg *config.Dendrite, +) { r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() authData := auth.Data{ @@ -71,4 +79,12 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, d } return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], vars["stateKey"]) })).Methods(http.MethodGet, http.MethodOptions) + + r0mux.Handle("/rooms/{roomID}/messages", common.MakeAuthAPI("room_messages", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars, err := common.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, queryAPI, cfg) + })).Methods(http.MethodGet, http.MethodOptions) } diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index 36ba88cd..94e6ac41 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrix" ) @@ -89,7 +90,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func (s *accountDataStatements) insertAccountData( ctx context.Context, userID, roomID, dataType string, -) (pos int64, err error) { +) (pos types.StreamPosition, err error) { err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos) return } @@ -97,7 +98,7 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountDataInRange( ctx context.Context, userID string, - oldPos, newPos int64, + oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart, ) (data map[string][]string, err error) { data = make(map[string][]string) diff --git a/syncapi/storage/postgres/backward_extremities_table.go b/syncapi/storage/postgres/backward_extremities_table.go new file mode 100644 index 00000000..476d26fa --- /dev/null +++ b/syncapi/storage/postgres/backward_extremities_table.go @@ -0,0 +1,118 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" +) + +const backwardExtremitiesSchema = ` +-- Stores output room events received from the roomserver. +CREATE TABLE IF NOT EXISTS syncapi_backward_extremities ( + -- The 'room_id' key for the event. + room_id TEXT NOT NULL, + -- The event ID for the event. + event_id TEXT NOT NULL, + + PRIMARY KEY(room_id, event_id) +); +` + +const insertBackwardExtremitySQL = "" + + "INSERT INTO syncapi_backward_extremities (room_id, event_id)" + + " VALUES ($1, $2)" + +const selectBackwardExtremitiesForRoomSQL = "" + + "SELECT event_id FROM syncapi_backward_extremities WHERE room_id = $1" + +const isBackwardExtremitySQL = "" + + "SELECT EXISTS (" + + " SELECT TRUE FROM syncapi_backward_extremities" + + " WHERE room_id = $1 AND event_id = $2" + + ")" + +const deleteBackwardExtremitySQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND event_id = $2" + +type backwardExtremitiesStatements struct { + insertBackwardExtremityStmt *sql.Stmt + selectBackwardExtremitiesForRoomStmt *sql.Stmt + isBackwardExtremityStmt *sql.Stmt + deleteBackwardExtremityStmt *sql.Stmt +} + +func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(backwardExtremitiesSchema) + if err != nil { + return + } + if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { + return + } + if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { + return + } + if s.isBackwardExtremityStmt, err = db.Prepare(isBackwardExtremitySQL); err != nil { + return + } + if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { + return + } + return +} + +func (s *backwardExtremitiesStatements) insertsBackwardExtremity( + ctx context.Context, roomID, eventID string, +) (err error) { + _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID) + return +} + +func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (eventIDs []string, err error) { + eventIDs = make([]string, 0) + + rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + if err != nil { + return + } + + for rows.Next() { + var eID string + if err = rows.Scan(&eID); err != nil { + return + } + + eventIDs = append(eventIDs, eID) + } + + return +} + +func (s *backwardExtremitiesStatements) isBackwardExtremity( + ctx context.Context, roomID, eventID string, +) (isBE bool, err error) { + err = s.isBackwardExtremityStmt.QueryRowContext(ctx, roomID, eventID).Scan(&isBE) + return +} + +func (s *backwardExtremitiesStatements) deleteBackwardExtremity( + ctx context.Context, roomID, eventID string, +) (err error) { + _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID) + return +} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 8b208043..816cbb44 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" ) @@ -87,10 +88,10 @@ const selectStateEventSQL = "" + const selectEventsWithEventIDsSQL = "" + // TODO: The session_id and transaction_id blanks are here because otherwise - // the rowsToStreamEvents expects there to be exactly four columns. We need to + // the rowsToStreamEvents expects there to be exactly five columns. We need to // figure out if these really need to be in the DB, and if so, we need a // better permanent fix for this. - neilalexander, 2 Jan 2020 - "SELECT added_at, event_json, 0 AS session_id, '' AS transaction_id" + + "SELECT added_at, event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + " FROM syncapi_current_room_state WHERE event_id = ANY($1)" type currentRoomStateStatements struct { @@ -213,7 +214,7 @@ func (s *currentRoomStateStatements) deleteRoomStateByEventID( func (s *currentRoomStateStatements) upsertRoomState( ctx context.Context, txn *sql.Tx, - event gomatrixserverlib.Event, membership *string, addedAt int64, + event gomatrixserverlib.Event, membership *string, addedAt types.StreamPosition, ) error { // Parse content as JSON and search for an "url" key containsURL := false @@ -242,7 +243,7 @@ func (s *currentRoomStateStatements) upsertRoomState( func (s *currentRoomStateStatements) selectEventsWithEventIDs( ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index ced4bfc4..ca4bbeb5 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -86,7 +87,7 @@ func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) { func (s *inviteEventsStatements) insertInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (streamPos int64, err error) { +) (streamPos types.StreamPosition, err error) { err = s.insertInviteEventStmt.QueryRowContext( ctx, inviteEvent.RoomID(), @@ -107,7 +108,7 @@ func (s *inviteEventsStatements) deleteInviteEvent( // selectInviteEventsInRange returns a map of room ID to invite event for the // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) selectInviteEventsInRange( - ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos int64, + ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition, ) (map[string]gomatrixserverlib.Event, error) { stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos) diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index ca271593..be302d73 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -22,6 +22,7 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrix" "github.com/lib/pq" @@ -36,28 +37,35 @@ CREATE SEQUENCE IF NOT EXISTS syncapi_stream_id; -- Stores output room events received from the roomserver. CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( - -- An incrementing ID which denotes the position in the log that this event resides at. - -- NB: 'serial' makes no guarantees to increment by 1 every time, only that it increments. - -- This isn't a problem for us since we just want to order by this field. - id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), - -- The event ID for the event - event_id TEXT NOT NULL, - -- The 'room_id' key for the event. - room_id TEXT NOT NULL, - -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. - event_json TEXT NOT NULL, - -- The event type e.g 'm.room.member'. - type TEXT NOT NULL, - -- The 'sender' property of the event. - sender TEXT NOT NULL, - -- true if the event content contains a url key. - contains_url BOOL NOT NULL, - -- A list of event IDs which represent a delta of added/removed room state. This can be NULL - -- if there is no delta. - add_state_ids TEXT[], - remove_state_ids TEXT[], - session_id BIGINT, -- The client session that sent the event, if any - transaction_id TEXT -- The transaction id used to send the event, if any + -- An incrementing ID which denotes the position in the log that this event resides at. + -- NB: 'serial' makes no guarantees to increment by 1 every time, only that it increments. + -- This isn't a problem for us since we just want to order by this field. + id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), + -- The event ID for the event + event_id TEXT NOT NULL, + -- The 'room_id' key for the event. + room_id TEXT NOT NULL, + -- The JSON for the event. Stored as TEXT because this should be valid UTF-8. + event_json TEXT NOT NULL, + -- The event type e.g 'm.room.member'. + type TEXT NOT NULL, + -- The 'sender' property of the event. + sender TEXT NOT NULL, + -- true if the event content contains a url key. + contains_url BOOL NOT NULL, + -- A list of event IDs which represent a delta of added/removed room state. This can be NULL + -- if there is no delta. + add_state_ids TEXT[], + remove_state_ids TEXT[], + -- The client session that sent the event, if any + session_id BIGINT, + -- The transaction id used to send the event, if any + transaction_id TEXT, + -- Should the event be excluded from responses to /sync requests. Useful for + -- events retrieved through backfilling that have a position in the stream + -- that relates to the moment these were retrieved rather than the moment these + -- were emitted. + exclude_from_sync BOOL DEFAULT FALSE ); -- for event selection CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_events(event_id); @@ -65,23 +73,33 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev const insertEventSQL = "" + "INSERT INTO syncapi_output_room_events (" + - "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" + - ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id" + "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" + + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id" const selectEventsSQL = "" + - "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" const selectRecentEventsSQL = "" + - "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" + + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + " ORDER BY id DESC LIMIT $4" +const selectRecentEventsForSyncSQL = "" + + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + + " ORDER BY id DESC LIMIT $4" + +const selectEarlyEventsSQL = "" + + "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + + " WHERE room_id = $1 AND id > $2 AND id <= $3" + + " ORDER BY id ASC LIMIT $4" + const selectMaxEventIDSQL = "" + "SELECT MAX(id) FROM syncapi_output_room_events" // In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id). const selectStateInRangeSQL = "" + - "SELECT id, event_json, add_state_ids, remove_state_ids" + + "SELECT id, event_json, exclude_from_sync, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + " AND ( $3::text[] IS NULL OR sender = ANY($3) )" + @@ -93,11 +111,13 @@ const selectStateInRangeSQL = "" + " LIMIT $8" type outputRoomEventsStatements struct { - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - selectRecentEventsStmt *sql.Stmt - selectStateInRangeStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + selectRecentEventsStmt *sql.Stmt + selectRecentEventsForSyncStmt *sql.Stmt + selectEarlyEventsStmt *sql.Stmt + selectStateInRangeStmt *sql.Stmt } func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { @@ -117,6 +137,12 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil { return } + if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil { + return + } + if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil { + return + } if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { return } @@ -127,9 +153,9 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) selectStateInRange( - ctx context.Context, txn *sql.Tx, oldPos, newPos int64, + ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, stateFilterPart *gomatrix.FilterPart, -) (map[string]map[string]bool, map[string]streamEvent, error) { +) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( @@ -149,19 +175,20 @@ func (s *outputRoomEventsStatements) selectStateInRange( // - For each room ID, build up an array of event IDs which represents cumulative adds/removes // For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID // if they aren't in the event ID cache. We don't handle state deletion yet. - eventIDToEvent := make(map[string]streamEvent) + eventIDToEvent := make(map[string]types.StreamEvent) // RoomID => A set (map[string]bool) of state event IDs which are between the two positions stateNeeded := make(map[string]map[string]bool) for rows.Next() { var ( - streamPos int64 - eventBytes []byte - addIDs pq.StringArray - delIDs pq.StringArray + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + addIDs pq.StringArray + delIDs pq.StringArray ) - if err := rows.Scan(&streamPos, &eventBytes, &addIDs, &delIDs); err != nil { + if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs); err != nil { return nil, nil, err } // Sanity check for deleted state and whine if we see it. We don't need to do anything @@ -192,9 +219,10 @@ func (s *outputRoomEventsStatements) selectStateInRange( } stateNeeded[ev.RoomID()] = needSet - eventIDToEvent[ev.EventID()] = streamEvent{ - Event: ev, - streamPosition: streamPos, + eventIDToEvent[ev.EventID()] = types.StreamEvent{ + Event: ev, + StreamPosition: streamPos, + ExcludeFromSync: excludeFromSync, } } @@ -221,8 +249,8 @@ func (s *outputRoomEventsStatements) selectMaxEventID( func (s *outputRoomEventsStatements) insertEvent( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.Event, addState, removeState []string, - transactionID *api.TransactionID, -) (streamPos int64, err error) { + transactionID *api.TransactionID, excludeFromSync bool, +) (streamPos types.StreamPosition, err error) { var txnID *string var sessionID *int64 if transactionID != nil { @@ -251,16 +279,53 @@ func (s *outputRoomEventsStatements) insertEvent( pq.StringArray(removeState), sessionID, txnID, + excludeFromSync, ).Scan(&streamPos) return } -// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'. +// selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'. +// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude +// from sync. func (s *outputRoomEventsStatements) selectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, fromPos, toPos int64, limit int, -) ([]streamEvent, error) { - stmt := common.TxStmt(txn, s.selectRecentEventsStmt) + roomID string, fromPos, toPos types.StreamPosition, limit int, + chronologicalOrder bool, onlySyncEvents bool, +) ([]types.StreamEvent, error) { + var stmt *sql.Stmt + if onlySyncEvents { + stmt = common.TxStmt(txn, s.selectRecentEventsForSyncStmt) + } else { + stmt = common.TxStmt(txn, s.selectRecentEventsStmt) + } + + rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + events, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if chronologicalOrder { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(events, func(i int, j int) bool { + return events[i].StreamPosition < events[j].StreamPosition + }) + } + return events, nil +} + +// selectEarlyEvents returns the earliest events in the given room, starting +// from a given position, up to a maximum of 'limit'. +func (s *outputRoomEventsStatements) selectEarlyEvents( + ctx context.Context, txn *sql.Tx, + roomID string, fromPos, toPos types.StreamPosition, limit int, +) ([]types.StreamEvent, error) { + stmt := common.TxStmt(txn, s.selectEarlyEventsStmt) rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) if err != nil { return nil, err @@ -274,16 +339,16 @@ func (s *outputRoomEventsStatements) selectRecentEvents( // necessarily the way the SQL query returns them, so a sort is necessary to // ensure the events are in the right order in the slice. sort.SliceStable(events, func(i int, j int) bool { - return events[i].streamPosition < events[j].streamPosition + return events[i].StreamPosition < events[j].StreamPosition }) return events, nil } -// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing -// from the database. +// selectEvents returns the events for the given event IDs. If an event is +// missing from the database, it will be omitted. func (s *outputRoomEventsStatements) selectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectEventsStmt) rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { @@ -293,17 +358,18 @@ func (s *outputRoomEventsStatements) selectEvents( return rowsToStreamEvents(rows) } -func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { - var result []streamEvent +func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { + var result []types.StreamEvent for rows.Next() { var ( - streamPos int64 - eventBytes []byte - sessionID *int64 - txnID *string - transactionID *api.TransactionID + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + sessionID *int64 + txnID *string + transactionID *api.TransactionID ) - if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil { + if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil { return nil, err } // TODO: Handle redacted events @@ -319,10 +385,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { } } - result = append(result, streamEvent{ - Event: ev, - streamPosition: streamPos, - transactionID: transactionID, + result = append(result, types.StreamEvent{ + Event: ev, + StreamPosition: streamPos, + TransactionID: transactionID, + ExcludeFromSync: excludeFromSync, }) } return result, nil diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go new file mode 100644 index 00000000..4a50b9a0 --- /dev/null +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -0,0 +1,188 @@ +// Copyright 2018 New Vector Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +const outputRoomEventsTopologySchema = ` +-- Stores output room events received from the roomserver. +CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology ( + -- The event ID for the event. + event_id TEXT PRIMARY KEY, + -- The place of the event in the room's topology. This can usually be determined + -- from the event's depth. + topological_position BIGINT NOT NULL, + -- The 'room_id' key for the event. + room_id TEXT NOT NULL +); +-- The topological order will be used in events selection and ordering +CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id); +` + +const insertEventInTopologySQL = "" + + "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" + + " VALUES ($1, $2, $3)" + + " ON CONFLICT DO NOTHING" + +const selectEventIDsInRangeASCSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + + " ORDER BY topological_position ASC LIMIT $4" + +const selectEventIDsInRangeDESCSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" + + " ORDER BY topological_position DESC LIMIT $4" + +const selectPositionInTopologySQL = "" + + "SELECT topological_position FROM syncapi_output_room_events_topology" + + " WHERE event_id = $1" + +const selectMaxPositionInTopologySQL = "" + + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1" + +const selectEventIDsFromPositionSQL = "" + + "SELECT event_id FROM syncapi_output_room_events_topology" + + " WHERE room_id = $1 AND topological_position = $2" + +type outputRoomEventsTopologyStatements struct { + insertEventInTopologyStmt *sql.Stmt + selectEventIDsInRangeASCStmt *sql.Stmt + selectEventIDsInRangeDESCStmt *sql.Stmt + selectPositionInTopologyStmt *sql.Stmt + selectMaxPositionInTopologyStmt *sql.Stmt + selectEventIDsFromPositionStmt *sql.Stmt +} + +func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(outputRoomEventsTopologySchema) + if err != nil { + return + } + if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { + return + } + if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { + return + } + if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { + return + } + if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { + return + } + if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { + return + } + if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil { + return + } + return +} + +// insertEventInTopology inserts the given event in the room's topology, based +// on the event's depth. +func (s *outputRoomEventsTopologyStatements) insertEventInTopology( + ctx context.Context, event *gomatrixserverlib.Event, +) (err error) { + _, err = s.insertEventInTopologyStmt.ExecContext( + ctx, event.EventID(), event.Depth(), event.RoomID(), + ) + return +} + +// selectEventIDsInRange selects the IDs of events which positions are within a +// given range in a given room's topological order. +// Returns an empty slice if no events match the given range. +func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange( + ctx context.Context, roomID string, fromPos, toPos types.StreamPosition, + limit int, chronologicalOrder bool, +) (eventIDs []string, err error) { + // Decide on the selection's order according to whether chronological order + // is requested or not. + var stmt *sql.Stmt + if chronologicalOrder { + stmt = s.selectEventIDsInRangeASCStmt + } else { + stmt = s.selectEventIDsInRangeDESCStmt + } + + // Query the event IDs. + rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) + if err == sql.ErrNoRows { + // If no event matched the request, return an empty slice. + return []string{}, nil + } else if err != nil { + return + } + + // Return the IDs. + var eventID string + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + + return +} + +// selectPositionInTopology returns the position of a given event in the +// topology of the room it belongs to. +func (s *outputRoomEventsTopologyStatements) selectPositionInTopology( + ctx context.Context, eventID string, +) (pos types.StreamPosition, err error) { + err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos) + return +} + +func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology( + ctx context.Context, roomID string, +) (pos types.StreamPosition, err error) { + err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos) + return +} + +// selectEventIDsFromPosition returns the IDs of all events that have a given +// position in the topology of a given room. +func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition( + ctx context.Context, roomID string, pos types.StreamPosition, +) (eventIDs []string, err error) { + // Query the event IDs. + rows, err := s.selectEventIDsFromPositionStmt.QueryContext(ctx, roomID, pos) + if err == sql.ErrNoRows { + // If no event matched the request, return an empty slice. + return []string{}, nil + } else if err != nil { + return + } + // Return the IDs. + var eventID string + for rows.Next() { + if err = rows.Scan(&eventID); err != nil { + return + } + eventIDs = append(eventIDs, eventID) + } + return +} diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 3a62d136..621aec95 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -20,7 +20,6 @@ import ( "database/sql" "encoding/json" "fmt" - "strconv" "time" "github.com/sirupsen/logrus" @@ -43,29 +42,24 @@ type stateDelta struct { membership string // The PDU stream position of the latest membership event for this user, if applicable. // Can be 0 if there is no membership event in this delta. - membershipPos int64 + membershipPos types.StreamPosition } -// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. -type streamEvent struct { - gomatrixserverlib.Event - streamPosition int64 - transactionID *api.TransactionID -} - -// SyncServerDatabase represents a sync server datasource which manages +// SyncServerDatasource represents a sync server datasource which manages // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { db *sql.DB common.PartitionOffsetStatements - accountData accountDataStatements - events outputRoomEventsStatements - roomstate currentRoomStateStatements - invites inviteEventsStatements - typingCache *cache.TypingCache + accountData accountDataStatements + events outputRoomEventsStatements + roomstate currentRoomStateStatements + invites inviteEventsStatements + typingCache *cache.TypingCache + topology outputRoomEventsTopologyStatements + backwardExtremities backwardExtremitiesStatements } -// NewSyncServerDatabase creates a new sync server database +// NewSyncServerDatasource creates a new sync server database func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) { var d SyncServerDatasource var err error @@ -87,6 +81,12 @@ func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, er if err := d.invites.prepare(d.db); err != nil { return nil, err } + if err := d.topology.prepare(d.db); err != nil { + return nil, err + } + if err := d.backwardExtremities.prepare(d.db); err != nil { + return nil, err + } d.typingCache = cache.NewTypingCache() return &d, nil } @@ -109,7 +109,46 @@ func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([ // We don't include a device here as we only include transaction IDs in // incremental syncs. - return streamEventsToEvents(nil, streamEvents), nil + return d.StreamEventsToEvents(nil, streamEvents), nil +} + +func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, ev *gomatrixserverlib.Event) error { + // If the event is already known as a backward extremity, don't consider + // it as such anymore now that we have it. + isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, ev.RoomID(), ev.EventID()) + if err != nil { + return err + } + if isBackwardExtremity { + if err = d.backwardExtremities.deleteBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil { + return err + } + } + + // Check if we have all of the event's previous events. If an event is + // missing, add it to the room's backward extremities. + prevEvents, err := d.events.selectEvents(ctx, nil, ev.PrevEventIDs()) + if err != nil { + return err + } + var found bool + for _, eID := range ev.PrevEventIDs() { + found = false + for _, prevEv := range prevEvents { + if eID == prevEv.EventID() { + found = true + } + } + + // If the event is missing, consider it a backward extremity. + if !found { + if err = d.backwardExtremities.insertsBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil { + return err + } + } + } + + return nil } // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races @@ -120,16 +159,26 @@ func (d *SyncServerDatasource) WriteEvent( ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, - transactionID *api.TransactionID, -) (pduPosition int64, returnErr error) { + transactionID *api.TransactionID, excludeFromSync bool, +) (pduPosition types.StreamPosition, returnErr error) { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error - pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID) + pos, err := d.events.insertEvent( + ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, + ) if err != nil { return err } pduPosition = pos + if err = d.topology.insertEventInTopology(ctx, ev); err != nil { + return err + } + + if err = d.handleBackwardExtremities(ctx, ev); err != nil { + return err + } + if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { // Nothing to do, the event may have just been a message event. return nil @@ -137,14 +186,15 @@ func (d *SyncServerDatasource) WriteEvent( return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition) }) - return + + return pduPosition, returnErr } func (d *SyncServerDatasource) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, addedEvents []gomatrixserverlib.Event, - pduPosition int64, + pduPosition types.StreamPosition, ) error { // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removedEventIDs { @@ -196,14 +246,141 @@ func (d *SyncServerDatasource) GetStateEventsForRoom( return } +// GetEventsInRange retrieves all of the events on a given ordering using the +// given extremities and limit. +func (d *SyncServerDatasource) GetEventsInRange( + ctx context.Context, + from, to *types.PaginationToken, + roomID string, limit int, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { + // If the pagination token's type is types.PaginationTokenTypeTopology, the + // events must be retrieved from the rooms' topology table rather than the + // table contaning the syncapi server's whole stream of events. + if from.Type == types.PaginationTokenTypeTopology { + // Determine the backward and forward limit, i.e. the upper and lower + // limits to the selection in the room's topology, from the direction. + var backwardLimit, forwardLimit types.StreamPosition + if backwardOrdering { + // Backward ordering is antichronological (latest event to oldest + // one). + backwardLimit = to.PDUPosition + forwardLimit = from.PDUPosition + } else { + // Forward ordering is chronological (oldest event to latest one). + backwardLimit = from.PDUPosition + forwardLimit = to.PDUPosition + } + + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.topology.selectEventIDsInRange( + ctx, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering, + ) + if err != nil { + return + } + + // Retrieve the events' contents using their IDs. + events, err = d.events.selectEvents(ctx, nil, eIDs) + return + } + + // If the pagination token's type is types.PaginationTokenTypeStream, the + // events must be retrieved from the table contaning the syncapi server's + // whole stream of events. + + if backwardOrdering { + // When using backward ordering, we want the most recent events first. + if events, err = d.events.selectRecentEvents( + ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, + ); err != nil { + return + } + } else { + // When using forward ordering, we want the least recent events first. + if events, err = d.events.selectEarlyEvents( + ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, + ); err != nil { + return + } + } + + return +} + // SyncPosition returns the latest positions for syncing. -func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) { +func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) { return d.syncPositionTx(ctx, nil) } +// BackwardExtremitiesForRoom returns the event IDs of all of the backward +// extremities we know of for a given room. +func (d *SyncServerDatasource) BackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (backwardExtremities []string, err error) { + return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, roomID) +} + +// MaxTopologicalPosition returns the highest topological position for a given +// room. +func (d *SyncServerDatasource) MaxTopologicalPosition( + ctx context.Context, roomID string, +) (types.StreamPosition, error) { + return d.topology.selectMaxPositionInTopology(ctx, roomID) +} + +// EventsAtTopologicalPosition returns all of the events matching a given +// position in the topology of a given room. +func (d *SyncServerDatasource) EventsAtTopologicalPosition( + ctx context.Context, roomID string, pos types.StreamPosition, +) ([]types.StreamEvent, error) { + eIDs, err := d.topology.selectEventIDsFromPosition(ctx, roomID, pos) + if err != nil { + return nil, err + } + + return d.events.selectEvents(ctx, nil, eIDs) +} + +func (d *SyncServerDatasource) EventPositionInTopology( + ctx context.Context, eventID string, +) (types.StreamPosition, error) { + return d.topology.selectPositionInTopology(ctx, eventID) +} + +// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet. +func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { + return d.syncStreamPositionTx(ctx, nil) +} + +func (d *SyncServerDatasource) syncStreamPositionTx( + ctx context.Context, txn *sql.Tx, +) (types.StreamPosition, error) { + maxID, err := d.events.selectMaxEventID(ctx, txn) + if err != nil { + return 0, err + } + maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) + if err != nil { + return 0, err + } + if maxAccountDataID > maxID { + maxID = maxAccountDataID + } + maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) + if err != nil { + return 0, err + } + if maxInviteID > maxID { + maxID = maxInviteID + } + return types.StreamPosition(maxID), nil +} + func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, -) (sp types.SyncPosition, err error) { +) (sp types.PaginationToken, err error) { maxEventID, err := d.events.selectMaxEventID(ctx, txn) if err != nil { @@ -223,10 +400,8 @@ func (d *SyncServerDatasource) syncPositionTx( if maxInviteID > maxEventID { maxEventID = maxInviteID } - sp.PDUPosition = maxEventID - - sp.TypingPosition = d.typingCache.GetLatestSyncPosition() - + sp.PDUPosition = types.StreamPosition(maxEventID) + sp.EDUTypingPosition = types.StreamPosition(d.typingCache.GetLatestSyncPosition()) return } @@ -235,7 +410,7 @@ func (d *SyncServerDatasource) syncPositionTx( func (d *SyncServerDatasource) addPDUDeltaToResponse( ctx context.Context, device authtypes.Device, - fromPos, toPos int64, + fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int, wantFullState bool, res *types.Response, @@ -287,7 +462,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse( // addTypingDeltaToResponse adds all typing notifications to a sync response // since the specified position. func (d *SyncServerDatasource) addTypingDeltaToResponse( - since int64, + since types.PaginationToken, joinedRoomIDs []string, res *types.Response, ) error { @@ -296,7 +471,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( var err error for _, roomID := range joinedRoomIDs { if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter( - roomID, since, + roomID, int64(since.EDUTypingPosition), ); updated { ev := gomatrixserverlib.ClientEvent{ Type: gomatrixserverlib.MTyping, @@ -321,14 +496,14 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // the positions of that type are not equal in fromPos and toPos. func (d *SyncServerDatasource) addEDUDeltaToResponse( - fromPos, toPos types.SyncPosition, + fromPos, toPos types.PaginationToken, joinedRoomIDs []string, res *types.Response, ) (err error) { - if fromPos.TypingPosition != toPos.TypingPosition { + if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { err = d.addTypingDeltaToResponse( - fromPos.TypingPosition, joinedRoomIDs, res, + fromPos, joinedRoomIDs, res, ) } @@ -343,7 +518,7 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) IncrementalSync( ctx context.Context, device authtypes.Device, - fromPos, toPos types.SyncPosition, + fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool, ) (*types.Response, error) { @@ -383,7 +558,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( numRecentEventsPerRoom int, ) ( res *types.Response, - toPos types.SyncPosition, + toPos types.PaginationToken, joinedRoomIDs []string, err error, ) { @@ -423,27 +598,37 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( } // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - var recentStreamEvents []streamEvent + var recentStreamEvents []types.StreamEvent recentStreamEvents, err = d.events.selectRecentEvents( - ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom, + ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, + numRecentEventsPerRoom, true, true, + //ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom, ) if err != nil { return } + // Retrieve the backward topology position, i.e. the position of the + // oldest event in the room's topology. + var backwardTopologyPos types.StreamPosition + backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) + if err != nil { + return nil, types.PaginationToken{}, []string{}, err + } + if backwardTopologyPos-1 <= 0 { + backwardTopologyPos = types.StreamPosition(1) + } else { + backwardTopologyPos = backwardTopologyPos - 1 + } + // We don't include a device here as we don't need to send down // transaction IDs for complete syncs - recentEvents := streamEventsToEvents(nil, recentStreamEvents) - + recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() - if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 { - // Use the short form of batch token for prev_batch - jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) - } else { - // Use the short form of batch token for prev_batch - jr.Timeline.PrevBatch = "1" - } + jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + ).String() jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = true jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync) @@ -471,7 +656,7 @@ func (d *SyncServerDatasource) CompleteSync( // Use a zero value SyncPosition for fromPos so all EDU states are added. err = d.addEDUDeltaToResponse( - types.SyncPosition{}, toPos, joinedRoomIDs, res, + types.PaginationToken{}, toPos, joinedRoomIDs, res, ) if err != nil { return nil, err @@ -496,7 +681,7 @@ var txReadOnlySnapshot = sql.TxOptions{ // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error func (d *SyncServerDatasource) GetAccountDataInRange( - ctx context.Context, userID string, oldPos, newPos int64, + ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart, ) (map[string][]string, error) { return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) @@ -510,7 +695,7 @@ func (d *SyncServerDatasource) GetAccountDataInRange( // Returns an error if there was an issue with the upsert func (d *SyncServerDatasource) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, -) (int64, error) { +) (types.StreamPosition, error) { return d.accountData.insertAccountData(ctx, userID, roomID, dataType) } @@ -519,7 +704,7 @@ func (d *SyncServerDatasource) UpsertAccountData( // Returns an error if there was a problem communicating with the database. func (d *SyncServerDatasource) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.Event, -) (int64, error) { +) (types.StreamPosition, error) { return d.invites.insertInviteEvent(ctx, inviteEvent) } @@ -542,26 +727,26 @@ func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallback // Returns the newly calculated sync position for typing notifications. func (d *SyncServerDatasource) AddTypingUser( userID, roomID string, expireTime *time.Time, -) int64 { - return d.typingCache.AddTypingUser(userID, roomID, expireTime) +) types.StreamPosition { + return types.StreamPosition(d.typingCache.AddTypingUser(userID, roomID, expireTime)) } // RemoveTypingUser removes a typing user from the typing cache. // Returns the newly calculated sync position for typing notifications. func (d *SyncServerDatasource) RemoveTypingUser( userID, roomID string, -) int64 { - return d.typingCache.RemoveUser(userID, roomID) +) types.StreamPosition { + return types.StreamPosition(d.typingCache.RemoveUser(userID, roomID)) } func (d *SyncServerDatasource) addInvitesToResponse( ctx context.Context, txn *sql.Tx, userID string, - fromPos, toPos int64, + fromPos, toPos types.StreamPosition, res *types.Response, ) error { invites, err := d.invites.selectInviteEventsInRange( - ctx, txn, userID, int64(fromPos), int64(toPos), + ctx, txn, userID, fromPos, toPos, ) if err != nil { return err @@ -577,12 +762,32 @@ func (d *SyncServerDatasource) addInvitesToResponse( return nil } +// Retrieve the backward topology position, i.e. the position of the +// oldest event in the room's topology. +func (d *SyncServerDatasource) getBackwardTopologyPos( + ctx context.Context, + events []types.StreamEvent, +) (pos types.StreamPosition, err error) { + if len(events) > 0 { + pos, err = d.topology.selectPositionInTopology(ctx, events[0].EventID()) + if err != nil { + return + } + } + if pos-1 <= 0 { + pos = types.StreamPosition(1) + } else { + pos = pos - 1 + } + return +} + // addRoomDeltaToResponse adds a room state delta to a sync response func (d *SyncServerDatasource) addRoomDeltaToResponse( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos int64, + fromPos, toPos types.StreamPosition, delta stateDelta, numRecentEventsPerRoom int, res *types.Response, @@ -598,38 +803,28 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( endPos = delta.membershipPos } recentStreamEvents, err := d.events.selectRecentEvents( - ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom, + ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos), + numRecentEventsPerRoom, true, true, ) if err != nil { return err } - recentEvents := streamEventsToEvents(device, recentStreamEvents) + recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back - var prevPDUPos int64 - - if len(recentEvents) == 0 { - if len(delta.stateEvents) == 0 { - // Don't bother appending empty room entries - return nil - } - - // If full_state=true and since is already up to date, then we'll have - // state events but no recent events. - prevPDUPos = toPos - 1 - } else { - prevPDUPos = recentStreamEvents[0].streamPosition - 1 - } - - if prevPDUPos <= 0 { - prevPDUPos = 1 + var backwardTopologyPos types.StreamPosition + backwardTopologyPos, err = d.getBackwardTopologyPos(ctx, recentStreamEvents) + if err != nil { + return err } switch delta.membership { case gomatrixserverlib.Join: jr := types.NewJoinResponse() - // Use the short form of batch token for prev_batch - jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) + + jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + ).String() jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -640,8 +835,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // TODO: recentEvents may contain events that this user is not allowed to see because they are // no longer in the room. lr := types.NewLeaveResponse() - // Use the short form of batch token for prev_batch - lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10) + lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( + types.PaginationTokenTypeTopology, backwardTopologyPos, 0, + ).String() lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) @@ -656,9 +852,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( func (d *SyncServerDatasource) fetchStateEvents( ctx context.Context, txn *sql.Tx, roomIDToEventIDSet map[string]map[string]bool, - eventIDToEvent map[string]streamEvent, -) (map[string][]streamEvent, error) { - stateBetween := make(map[string][]streamEvent) + eventIDToEvent map[string]types.StreamEvent, +) (map[string][]types.StreamEvent, error) { + stateBetween := make(map[string][]types.StreamEvent) missingEvents := make(map[string][]string) for roomID, ids := range roomIDToEventIDSet { events := stateBetween[roomID] @@ -700,7 +896,7 @@ func (d *SyncServerDatasource) fetchStateEvents( func (d *SyncServerDatasource) fetchMissingStateEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. events, err := d.events.selectEvents(ctx, txn, eventIDs) @@ -743,7 +939,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents( // A list of joined room IDs is also returned in case the caller needs it. func (d *SyncServerDatasource) getStateDeltas( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos int64, userID string, + fromPos, toPos types.StreamPosition, userID string, stateFilterPart *gomatrix.FilterPart, ) ([]stateDelta, []string, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 @@ -776,7 +972,7 @@ func (d *SyncServerDatasource) getStateDeltas( if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { if membership == gomatrixserverlib.Join { // send full room state down instead of a delta - var s []streamEvent + var s []types.StreamEvent s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart) if err != nil { return nil, nil, err @@ -787,8 +983,8 @@ func (d *SyncServerDatasource) getStateDeltas( deltas = append(deltas, stateDelta{ membership: membership, - membershipPos: ev.streamPosition, - stateEvents: streamEventsToEvents(device, stateStreamEvents), + membershipPos: ev.StreamPosition, + stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), roomID: roomID, }) break @@ -804,7 +1000,7 @@ func (d *SyncServerDatasource) getStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, stateDelta{ membership: gomatrixserverlib.Join, - stateEvents: streamEventsToEvents(device, state[joinedRoomID]), + stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), roomID: joinedRoomID, }) } @@ -818,7 +1014,7 @@ func (d *SyncServerDatasource) getStateDeltas( // updates for other rooms. func (d *SyncServerDatasource) getStateDeltasForFullStateSync( ctx context.Context, device *authtypes.Device, txn *sql.Tx, - fromPos, toPos int64, userID string, + fromPos, toPos types.StreamPosition, userID string, stateFilterPart *gomatrix.FilterPart, ) ([]stateDelta, []string, error) { joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) @@ -837,7 +1033,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( } deltas = append(deltas, stateDelta{ membership: gomatrixserverlib.Join, - stateEvents: streamEventsToEvents(device, s), + stateEvents: d.StreamEventsToEvents(device, s), roomID: joinedRoomID, }) } @@ -858,8 +1054,8 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. deltas = append(deltas, stateDelta{ membership: membership, - membershipPos: ev.streamPosition, - stateEvents: streamEventsToEvents(device, stateStreamEvents), + membershipPos: ev.StreamPosition, + stateEvents: d.StreamEventsToEvents(device, stateStreamEvents), roomID: roomID, }) } @@ -875,29 +1071,29 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( func (d *SyncServerDatasource) currentStateStreamEventsForRoom( ctx context.Context, txn *sql.Tx, roomID string, stateFilterPart *gomatrix.FilterPart, -) ([]streamEvent, error) { +) ([]types.StreamEvent, error) { allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart) if err != nil { return nil, err } - s := make([]streamEvent, len(allState)) + s := make([]types.StreamEvent, len(allState)) for i := 0; i < len(s); i++ { - s[i] = streamEvent{Event: allState[i], streamPosition: 0} + s[i] = types.StreamEvent{Event: allState[i], StreamPosition: 0} } return s, nil } -// streamEventsToEvents converts streamEvent to Event. If device is non-nil and +// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. -func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event { +func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event { out := make([]gomatrixserverlib.Event, len(in)) for i := 0; i < len(in); i++ { out[i] = in[i].Event - if device != nil && in[i].transactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID { + if device != nil && in[i].TransactionID != nil { + if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { err := out[i].SetUnsignedField( - "transaction_id", in[i].transactionID.TransactionID, + "transaction_id", in[i].TransactionID.TransactionID, ) if err != nil { logrus.WithFields(logrus.Fields{ diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go index 5db4b3a1..4e8a2c83 100644 --- a/syncapi/storage/storage.go +++ b/syncapi/storage/storage.go @@ -33,19 +33,26 @@ type Database interface { common.PartitionStorer AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) - WriteEvent(ctx context.Context, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID) (pduPosition int64, returnErr error) + WriteEvent(context.Context, *gomatrixserverlib.Event, []gomatrixserverlib.Event, []string, []string, *api.TransactionID, bool) (types.StreamPosition, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error) GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart) (stateEvents []gomatrixserverlib.Event, err error) - SyncPosition(ctx context.Context) (types.SyncPosition, error) - IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.SyncPosition, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + SyncPosition(ctx context.Context) (types.PaginationToken, error) + IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) - GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos int64, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error) - UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (int64, error) - AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (int64, error) + GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error) + UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (types.StreamPosition, error) + AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (types.StreamPosition, error) RetireInviteEvent(ctx context.Context, inviteEventID string) error SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) - AddTypingUser(userID, roomID string, expireTime *time.Time) int64 - RemoveTypingUser(userID, roomID string) int64 + AddTypingUser(userID, roomID string, expireTime *time.Time) types.StreamPosition + RemoveTypingUser(userID, roomID string) types.StreamPosition + GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) + EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error) + EventsAtTopologicalPosition(ctx context.Context, roomID string, pos types.StreamPosition) ([]types.StreamEvent, error) + BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error) + MaxTopologicalPosition(ctx context.Context, roomID string) (types.StreamPosition, error) + StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event + SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) } // NewPublicRoomsServerDatabase opens a database connection. diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 548a17ac..aaee49d3 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -36,7 +36,7 @@ type Notifier struct { // Protects currPos and userStreams. streamLock *sync.Mutex // The latest sync position - currPos types.SyncPosition + currPos types.PaginationToken // A map of user_id => UserStream which can be used to wake a given user's /sync request. userStreams map[string]*UserStream // The last time we cleaned out stale entries from the userStreams map @@ -46,7 +46,7 @@ type Notifier struct { // NewNotifier creates a new notifier set to the given sync position. // In order for this to be of any use, the Notifier needs to be told all rooms and // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). -func NewNotifier(pos types.SyncPosition) *Notifier { +func NewNotifier(pos types.PaginationToken) *Notifier { return &Notifier{ currPos: pos, roomIDToJoinedUsers: make(map[string]userIDSet), @@ -68,7 +68,7 @@ func NewNotifier(pos types.SyncPosition) *Notifier { // event type it handles, leaving other fields as 0. func (n *Notifier) OnNewEvent( ev *gomatrixserverlib.Event, roomID string, userIDs []string, - posUpdate types.SyncPosition, + posUpdate types.PaginationToken, ) { // update the current position then notify relevant /sync streams. // This needs to be done PRIOR to waking up users as they will read this value. @@ -151,7 +151,7 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error { } // CurrentPosition returns the current sync position -func (n *Notifier) CurrentPosition() types.SyncPosition { +func (n *Notifier) CurrentPosition() types.PaginationToken { n.streamLock.Lock() defer n.streamLock.Unlock() @@ -173,7 +173,7 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) { } } -func (n *Notifier) wakeupUsers(userIDs []string, newPos types.SyncPosition) { +func (n *Notifier) wakeupUsers(userIDs []string, newPos types.PaginationToken) { for _, userID := range userIDs { stream := n.fetchUserStream(userID, false) if stream != nil { diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go index 808e07cc..02da0f7e 100644 --- a/syncapi/sync/notifier_test.go +++ b/syncapi/sync/notifier_test.go @@ -32,11 +32,11 @@ var ( randomMessageEvent gomatrixserverlib.Event aliceInviteBobEvent gomatrixserverlib.Event bobLeaveEvent gomatrixserverlib.Event - syncPositionVeryOld types.SyncPosition - syncPositionBefore types.SyncPosition - syncPositionAfter types.SyncPosition - syncPositionNewEDU types.SyncPosition - syncPositionAfter2 types.SyncPosition + syncPositionVeryOld types.PaginationToken + syncPositionBefore types.PaginationToken + syncPositionAfter types.PaginationToken + syncPositionNewEDU types.PaginationToken + syncPositionAfter2 types.PaginationToken ) var ( @@ -46,9 +46,9 @@ var ( ) func init() { - baseSyncPos := types.SyncPosition{ - PDUPosition: 0, - TypingPosition: 0, + baseSyncPos := types.PaginationToken{ + PDUPosition: 0, + EDUTypingPosition: 0, } syncPositionVeryOld = baseSyncPos @@ -61,7 +61,7 @@ func init() { syncPositionAfter.PDUPosition = 12 syncPositionNewEDU = syncPositionAfter - syncPositionNewEDU.TypingPosition = 1 + syncPositionNewEDU.EDUTypingPosition = 1 syncPositionAfter2 = baseSyncPos syncPositionAfter2.PDUPosition = 13 @@ -119,7 +119,7 @@ func TestImmediateNotification(t *testing.T) { t.Fatalf("TestImmediateNotification error: %s", err) } if pos != syncPositionBefore { - t.Fatalf("TestImmediateNotification want %d, got %d", syncPositionBefore, pos) + t.Fatalf("TestImmediateNotification want %v, got %v", syncPositionBefore, pos) } } @@ -138,7 +138,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) { t.Errorf("TestNewEventAndJoinedToRoom error: %s", err) } if pos != syncPositionAfter { - t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", syncPositionAfter, pos) + t.Errorf("TestNewEventAndJoinedToRoom want %v, got %v", syncPositionAfter, pos) } wg.Done() }() @@ -166,7 +166,7 @@ func TestNewInviteEventForUser(t *testing.T) { t.Errorf("TestNewInviteEventForUser error: %s", err) } if pos != syncPositionAfter { - t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionAfter, pos) + t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionAfter, pos) } wg.Done() }() @@ -194,7 +194,7 @@ func TestEDUWakeup(t *testing.T) { t.Errorf("TestNewInviteEventForUser error: %s", err) } if pos != syncPositionNewEDU { - t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionNewEDU, pos) + t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionNewEDU, pos) } wg.Done() }() @@ -222,7 +222,7 @@ func TestMultipleRequestWakeup(t *testing.T) { t.Errorf("TestMultipleRequestWakeup error: %s", err) } if pos != syncPositionAfter { - t.Errorf("TestMultipleRequestWakeup want %d, got %d", syncPositionAfter, pos) + t.Errorf("TestMultipleRequestWakeup want %v, got %v", syncPositionAfter, pos) } wg.Done() } @@ -262,7 +262,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) } if pos != syncPositionAfter { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter, pos) + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter, pos) } leaveWG.Done() }() @@ -281,7 +281,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err) } if pos != syncPositionAfter2 { - t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter2, pos) + t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter2, pos) } aliceWG.Done() }() @@ -305,14 +305,14 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) { time.Sleep(1 * time.Millisecond) } -func waitForEvents(n *Notifier, req syncRequest) (types.SyncPosition, error) { +func waitForEvents(n *Notifier, req syncRequest) (types.PaginationToken, error) { listener := n.GetListener(req) defer listener.Close() select { case <-time.After(5 * time.Second): - return types.SyncPosition{}, fmt.Errorf( - "waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since, + return types.PaginationToken{}, fmt.Errorf( + "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since, ) case <-listener.GetNotifyChannel(*req.since): p := listener.GetSyncPosition() @@ -337,7 +337,7 @@ func lockedFetchUserStream(n *Notifier, userID string) *UserStream { return n.fetchUserStream(userID, true) } -func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest { +func newTestSyncRequest(userID string, since types.PaginationToken) syncRequest { return syncRequest{ device: authtypes.Device{UserID: userID}, timeout: 1 * time.Minute, diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index a5d2f60f..f2e199d2 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -16,10 +16,8 @@ package sync import ( "context" - "errors" "net/http" "strconv" - "strings" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -38,7 +36,7 @@ type syncRequest struct { device authtypes.Device limit int timeout time.Duration - since *types.SyncPosition // nil means that no since token was supplied + since *types.PaginationToken // nil means that no since token was supplied wantFullState bool log *log.Entry } @@ -47,7 +45,7 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e timeout := getTimeout(req.URL.Query().Get("timeout")) fullState := req.URL.Query().Get("full_state") wantFullState := fullState != "" && fullState != "false" - since, err := getSyncStreamPosition(req.URL.Query().Get("since")) + since, err := getPaginationToken(req.URL.Query().Get("since")) if err != nil { return nil, err } @@ -75,41 +73,14 @@ func getTimeout(timeoutMS string) time.Duration { } // getSyncStreamPosition tries to parse a 'since' token taken from the API to a -// types.SyncPosition. If the string is empty then (nil, nil) is returned. +// types.PaginationToken. If the string is empty then (nil, nil) is returned. // There are two forms of tokens: The full length form containing all PDU and EDU // positions separated by "_", and the short form containing only the PDU // position. Short form can be used for, e.g., `prev_batch` tokens. -func getSyncStreamPosition(since string) (*types.SyncPosition, error) { +func getPaginationToken(since string) (*types.PaginationToken, error) { if since == "" { return nil, nil } - posStrings := strings.Split(since, "_") - if len(posStrings) != 2 && len(posStrings) != 1 { - // A token can either be full length or short (PDU-only). - return nil, errors.New("malformed batch token") - } - - positions := make([]int64, len(posStrings)) - for i, posString := range posStrings { - pos, err := strconv.ParseInt(posString, 10, 64) - if err != nil { - return nil, err - } - positions[i] = pos - } - - if len(positions) == 2 { - // Full length token; construct SyncPosition with every entry in - // `positions`. These entries must have the same order with the fields - // in struct SyncPosition, so we disable the govet check below. - return &types.SyncPosition{ //nolint:govet - positions[0], positions[1], - }, nil - } else { - // Token with PDU position only - return &types.SyncPosition{ - PDUPosition: positions[0], - }, nil - } + return types.NewPaginationTokenFromString(since) } diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index d75f07e6..5a3ae880 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -130,7 +130,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype } } -func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncPosition) (res *types.Response, err error) { +func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.PaginationToken) (res *types.Response, err error) { // TODO: handle ignored users if req.since == nil { res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) @@ -143,7 +143,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncP } accountDataFilter := gomatrix.DefaultFilterPart() // TODO: use filter provided in req instead - res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter) + res, err = rp.appendAccountData(res, req.device.UserID, req, int64(latestPos.PDUPosition), &accountDataFilter) return } @@ -183,7 +183,11 @@ func (rp *RequestPool) appendAccountData( } // Sync is not initial, get all account data since the latest sync - dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since.PDUPosition, currentPos, accountDataFilter) + dataTypes, err := rp.db.GetAccountDataInRange( + req.ctx, userID, + types.StreamPosition(req.since.PDUPosition), types.StreamPosition(currentPos), + accountDataFilter, + ) if err != nil { return nil, err } diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go index beb10e48..6eef8644 100644 --- a/syncapi/sync/userstream.go +++ b/syncapi/sync/userstream.go @@ -35,7 +35,7 @@ type UserStream struct { // Closed when there is an update. signalChannel chan struct{} // The last sync position that there may have been an update for the user - pos types.SyncPosition + pos types.PaginationToken // The last time when we had some listeners waiting timeOfLastChannel time.Time // The number of listeners waiting @@ -51,7 +51,7 @@ type UserStreamListener struct { } // NewUserStream creates a new user stream -func NewUserStream(userID string, currPos types.SyncPosition) *UserStream { +func NewUserStream(userID string, currPos types.PaginationToken) *UserStream { return &UserStream{ UserID: userID, timeOfLastChannel: time.Now(), @@ -85,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener { } // Broadcast a new sync position for this user. -func (s *UserStream) Broadcast(pos types.SyncPosition) { +func (s *UserStream) Broadcast(pos types.PaginationToken) { s.lock.Lock() defer s.lock.Unlock() @@ -120,7 +120,7 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time { // GetStreamPosition returns last sync position which the UserStream was // notified about -func (s *UserStreamListener) GetSyncPosition() types.SyncPosition { +func (s *UserStreamListener) GetSyncPosition() types.PaginationToken { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() @@ -132,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.SyncPosition { // sincePos specifies from which point we want to be notified about. If there // has already been an update after sincePos we'll return a closed channel // immediately. -func (s *UserStreamListener) GetNotifyChannel(sincePos types.SyncPosition) <-chan struct{} { +func (s *UserStreamListener) GetNotifyChannel(sincePos types.PaginationToken) <-chan struct{} { s.userStream.lock.Lock() defer s.userStream.lock.Unlock() diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 4738feea..ecf532ca 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -21,7 +21,9 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/common/basecomponent" + "github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/syncapi/consumers" @@ -37,6 +39,8 @@ func SetupSyncAPIComponent( deviceDB *devices.Database, accountsDB *accounts.Database, queryAPI api.RoomserverQueryAPI, + federation *gomatrixserverlib.FederationClient, + cfg *config.Dendrite, ) { syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI)) if err != nil { @@ -77,5 +81,5 @@ func SetupSyncAPIComponent( logrus.WithError(err).Panicf("failed to start typing server consumer") } - routing.Setup(base.APIMux, requestPool, syncDB, deviceDB) + routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, federation, queryAPI, cfg) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index af7ec865..c25a38cd 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -16,45 +16,144 @@ package types import ( "encoding/json" + "errors" + "fmt" "strconv" + "strings" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" ) -// SyncPosition contains the PDU and EDU stream sync positions for a client. -type SyncPosition struct { - // PDUPosition is the stream position for PDUs the client is at. - PDUPosition int64 - // TypingPosition is the client's position for typing notifications. - TypingPosition int64 +var ( + // ErrInvalidPaginationTokenType is returned when an attempt at creating a + // new instance of PaginationToken with an invalid type (i.e. neither "s" + // nor "t"). + ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)") + // ErrInvalidPaginationTokenLen is returned when the pagination token is an + // invalid length + ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length") +) + +// StreamPosition represents the offset in the sync stream a client is at. +type StreamPosition int64 + +// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. +type StreamEvent struct { + gomatrixserverlib.Event + StreamPosition StreamPosition + TransactionID *api.TransactionID + ExcludeFromSync bool } -// String implements the Stringer interface. -func (sp SyncPosition) String() string { - return strconv.FormatInt(sp.PDUPosition, 10) + "_" + - strconv.FormatInt(sp.TypingPosition, 10) +// PaginationTokenType represents the type of a pagination token. +// It can be either "s" (representing a position in the whole stream of events) +// or "t" (representing a position in a room's topology/depth). +type PaginationTokenType string + +const ( + // PaginationTokenTypeStream represents a position in the server's whole + // stream of events + PaginationTokenTypeStream PaginationTokenType = "s" + // PaginationTokenTypeTopology represents a position in a room's topology. + PaginationTokenTypeTopology PaginationTokenType = "t" +) + +// PaginationToken represents a pagination token, used for interactions with +// /sync or /messages, for example. +type PaginationToken struct { + //Position StreamPosition + Type PaginationTokenType + PDUPosition StreamPosition + EDUTypingPosition StreamPosition } -// IsAfter returns whether one SyncPosition refers to states newer than another SyncPosition. -func (sp SyncPosition) IsAfter(other SyncPosition) bool { - return sp.PDUPosition > other.PDUPosition || - sp.TypingPosition > other.TypingPosition +// NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x" +// represents the type of a pagination token and "yyyy..." the token itself, and +// parses it in order to create a new instance of PaginationToken. Returns an +// error if the token couldn't be parsed into an int64, or if the token type +// isn't a known type (returns ErrInvalidPaginationTokenType in the latter +// case). +func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) { + if len(s) == 0 { + return nil, ErrInvalidPaginationTokenLen + } + + token = new(PaginationToken) + var positions []string + + switch t := PaginationTokenType(s[:1]); t { + case PaginationTokenTypeStream, PaginationTokenTypeTopology: + token.Type = t + positions = strings.Split(s[1:], "_") + default: + token.Type = PaginationTokenTypeStream + positions = strings.Split(s, "_") + } + + // Try to get the PDU position. + if len(positions) >= 1 { + if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil { + return nil, err + } else if pduPos < 0 { + return nil, errors.New("negative PDU position not allowed") + } else { + token.PDUPosition = StreamPosition(pduPos) + } + } + + // Try to get the typing position. + if len(positions) >= 2 { + if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil { + return nil, err + } else if typPos < 0 { + return nil, errors.New("negative EDU typing position not allowed") + } else { + token.EDUTypingPosition = StreamPosition(typPos) + } + } + + return +} + +// NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a +// StreamPosition and returns an instance of PaginationToken. +func NewPaginationTokenFromTypeAndPosition( + t PaginationTokenType, pdupos StreamPosition, typpos StreamPosition, +) (p *PaginationToken) { + return &PaginationToken{ + Type: t, + PDUPosition: pdupos, + EDUTypingPosition: typpos, + } } -// WithUpdates returns a copy of the SyncPosition with updates applied from another SyncPosition. -// If the latter SyncPosition contains a field that is not 0, it is considered an update, -// and its value will replace the corresponding value in the SyncPosition on which WithUpdates is called. -func (sp SyncPosition) WithUpdates(other SyncPosition) SyncPosition { - ret := sp +// String translates a PaginationToken to a string of the "xyyyy..." (see +// NewPaginationToken to know what it represents). +func (p *PaginationToken) String() string { + return fmt.Sprintf("%s%d_%d", p.Type, p.PDUPosition, p.EDUTypingPosition) +} + +// WithUpdates returns a copy of the PaginationToken with updates applied from another PaginationToken. +// If the latter PaginationToken contains a field that is not 0, it is considered an update, +// and its value will replace the corresponding value in the PaginationToken on which WithUpdates is called. +func (pt *PaginationToken) WithUpdates(other PaginationToken) PaginationToken { + ret := *pt if other.PDUPosition != 0 { ret.PDUPosition = other.PDUPosition } - if other.TypingPosition != 0 { - ret.TypingPosition = other.TypingPosition + if other.EDUTypingPosition != 0 { + ret.EDUTypingPosition = other.EDUTypingPosition } return ret } +// IsAfter returns whether one PaginationToken refers to states newer than another PaginationToken. +func (sp *PaginationToken) IsAfter(other PaginationToken) bool { + return sp.PDUPosition > other.PDUPosition || + sp.EDUTypingPosition > other.EDUTypingPosition +} + // PrevEventRef represents a reference to a previous event in a state event upgrade type PrevEventRef struct { PrevContent json.RawMessage `json:"prev_content"` @@ -79,9 +178,9 @@ type Response struct { } // NewResponse creates an empty response with initialised maps. -func NewResponse(pos SyncPosition) *Response { +func NewResponse(token PaginationToken) *Response { res := Response{ - NextBatch: pos.String(), + NextBatch: token.String(), } // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. @@ -96,6 +195,14 @@ func NewResponse(pos SyncPosition) *Response { res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) + // Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume + // we'll always return a stream token. + res.NextBatch = NewPaginationTokenFromTypeAndPosition( + PaginationTokenTypeStream, + StreamPosition(token.PDUPosition), + StreamPosition(token.EDUTypingPosition), + ).String() + return &res } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go new file mode 100644 index 00000000..f4c84e0d --- /dev/null +++ b/syncapi/types/types_test.go @@ -0,0 +1,52 @@ +package types + +import "testing" + +func TestNewPaginationTokenFromString(t *testing.T) { + shouldPass := map[string]PaginationToken{ + "2": PaginationToken{ + Type: PaginationTokenTypeStream, + PDUPosition: 2, + }, + "s4": PaginationToken{ + Type: PaginationTokenTypeStream, + PDUPosition: 4, + }, + "s3_1": PaginationToken{ + Type: PaginationTokenTypeStream, + PDUPosition: 3, + EDUTypingPosition: 1, + }, + "t3_1_4": PaginationToken{ + Type: PaginationTokenTypeTopology, + PDUPosition: 3, + EDUTypingPosition: 1, + }, + } + + shouldFail := []string{ + "", + "s_1", + "s_", + "a3_4", + "b", + "b-1", + "-4", + } + + for test, expected := range shouldPass { + result, err := NewPaginationTokenFromString(test) + if err != nil { + t.Error(err) + } + if *result != expected { + t.Errorf("expected %v but got %v", expected.String(), result.String()) + } + } + + for _, test := range shouldFail { + if _, err := NewPaginationTokenFromString(test); err == nil { + t.Errorf("input '%v' should have errored but didn't", test) + } + } +} |