aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-01-23 17:51:10 +0000
committerGitHub <noreply@github.com>2020-01-23 17:51:10 +0000
commit49f760a30b6496c8b3e1ceaf98dccc4376f6605d (patch)
treeb00d3fc17144cc83df1e5c7b8d1080ca19041243 /syncapi
parent43ecf8d1f909f4eb71bba93f6e7a57db59ec5941 (diff)
CS API: Support for /messages, fixes for /sync (#847)
* Merge forward * Tidy up a bit * TODO: What to do with NextBatch here? * Replace SyncPosition with PaginationToken throughout syncapi * Fix PaginationTokens * Fix lint errors * Add a couple of missing functions into the syncapi external storage interface * Some updates based on review comments from @babolivier * Some updates based on review comments from @babolivier * argh whitespacing * Fix opentracing span * Remove dead code * Don't overshadow err (fix lint issue) * Handle extremities after inserting event into topology * Try insert event topology as ON CONFLICT DO NOTHING * Prevent OOB error in addRoomDeltaToResponse * Thwarted by gocyclo again * Fix NewPaginationTokenFromString, define unit test for it * Update pagination token test * Update sytest-whitelist * Hopefully fix some of the sync batch tokens * Remove extraneous sync position func * Revert to topology tokens in addRoomDeltaToResponse etc * Fix typo * Remove prevPDUPos as dead now that backwardTopologyPos is used instead * Fix selectEventsWithEventIDsSQL * Update sytest-blacklist * Update sytest-whitelist
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/consumers/clientapi.go2
-rw-r--r--syncapi/consumers/roomserver.go5
-rw-r--r--syncapi/consumers/typingserver.go11
-rw-r--r--syncapi/routing/messages.go482
-rw-r--r--syncapi/routing/routing.go18
-rw-r--r--syncapi/storage/postgres/account_data_table.go5
-rw-r--r--syncapi/storage/postgres/backward_extremities_table.go118
-rw-r--r--syncapi/storage/postgres/current_room_state_table.go9
-rw-r--r--syncapi/storage/postgres/invites_table.go5
-rw-r--r--syncapi/storage/postgres/output_room_events_table.go197
-rw-r--r--syncapi/storage/postgres/output_room_events_topology_table.go188
-rw-r--r--syncapi/storage/postgres/syncserver.go400
-rw-r--r--syncapi/storage/storage.go23
-rw-r--r--syncapi/sync/notifier.go10
-rw-r--r--syncapi/sync/notifier_test.go40
-rw-r--r--syncapi/sync/request.go39
-rw-r--r--syncapi/sync/requestpool.go10
-rw-r--r--syncapi/sync/userstream.go10
-rw-r--r--syncapi/syncapi.go6
-rw-r--r--syncapi/types/types.go153
-rw-r--r--syncapi/types/types_test.go52
21 files changed, 1502 insertions, 281 deletions
diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go
index ed39cd2d..17f2c522 100644
--- a/syncapi/consumers/clientapi.go
+++ b/syncapi/consumers/clientapi.go
@@ -90,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data")
}
- s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.SyncPosition{PDUPosition: pduPos})
+ s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.PaginationToken{PDUPosition: pduPos})
return nil
}
diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go
index cde2f508..ba1b7dc5 100644
--- a/syncapi/consumers/roomserver.go
+++ b/syncapi/consumers/roomserver.go
@@ -133,6 +133,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
msg.AddsStateEventIDs,
msg.RemovesStateEventIDs,
msg.TransactionID,
+ false,
)
if err != nil {
// panic rather than continue with an inconsistent database
@@ -144,7 +145,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write event failure")
return nil
}
- s.notifier.OnNewEvent(&ev, "", nil, types.SyncPosition{PDUPosition: pduPos})
+ s.notifier.OnNewEvent(&ev, "", nil, types.PaginationToken{PDUPosition: pduPos})
return nil
}
@@ -161,7 +162,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure")
return nil
}
- s.notifier.OnNewEvent(&msg.Event, "", nil, types.SyncPosition{PDUPosition: pduPos})
+ s.notifier.OnNewEvent(&msg.Event, "", nil, types.PaginationToken{PDUPosition: pduPos})
return nil
}
diff --git a/syncapi/consumers/typingserver.go b/syncapi/consumers/typingserver.go
index 392f7987..36925441 100644
--- a/syncapi/consumers/typingserver.go
+++ b/syncapi/consumers/typingserver.go
@@ -63,7 +63,12 @@ func NewOutputTypingEventConsumer(
// Start consuming from typing api
func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
- s.notifier.OnNewEvent(nil, roomID, nil, types.SyncPosition{TypingPosition: latestSyncPosition})
+ s.notifier.OnNewEvent(
+ nil, roomID, nil,
+ types.PaginationToken{
+ EDUTypingPosition: types.StreamPosition(latestSyncPosition),
+ },
+ )
})
return s.typingConsumer.Start()
@@ -83,7 +88,7 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
"typing": output.Event.Typing,
}).Debug("received data from typing server")
- var typingPos int64
+ var typingPos types.StreamPosition
typingEvent := output.Event
if typingEvent.Typing {
typingPos = s.db.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime)
@@ -91,6 +96,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
}
- s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.SyncPosition{TypingPosition: typingPos})
+ s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.PaginationToken{EDUTypingPosition: typingPos})
return nil
}
diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go
new file mode 100644
index 00000000..26f48ca4
--- /dev/null
+++ b/syncapi/routing/messages.go
@@ -0,0 +1,482 @@
+// Copyright 2018 New Vector Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing
+
+import (
+ "context"
+ "net/http"
+ "sort"
+ "strconv"
+
+ "github.com/matrix-org/dendrite/clientapi/httputil"
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ "github.com/matrix-org/dendrite/common/config"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ log "github.com/sirupsen/logrus"
+)
+
+type messagesReq struct {
+ ctx context.Context
+ db storage.Database
+ queryAPI api.RoomserverQueryAPI
+ federation *gomatrixserverlib.FederationClient
+ cfg *config.Dendrite
+ roomID string
+ from *types.PaginationToken
+ to *types.PaginationToken
+ wasToProvided bool
+ limit int
+ backwardOrdering bool
+}
+
+type messagesResp struct {
+ Start string `json:"start"`
+ End string `json:"end"`
+ Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
+}
+
+const defaultMessagesLimit = 10
+
+// OnIncomingMessagesRequest implements the /messages endpoint from the
+// client-server API.
+// See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages
+func OnIncomingMessagesRequest(
+ req *http.Request, db storage.Database, roomID string,
+ federation *gomatrixserverlib.FederationClient,
+ queryAPI api.RoomserverQueryAPI,
+ cfg *config.Dendrite,
+) util.JSONResponse {
+ var err error
+
+ // Extract parameters from the request's URL.
+ // Pagination tokens.
+ from, err := types.NewPaginationTokenFromString(req.URL.Query().Get("from"))
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()),
+ }
+ }
+
+ // Direction to return events from.
+ dir := req.URL.Query().Get("dir")
+ if dir != "b" && dir != "f" {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"),
+ }
+ }
+ // A boolean is easier to handle in this case, especially since dir is sure
+ // to have one of the two accepted values (so dir == "f" <=> !backwardOrdering).
+ backwardOrdering := (dir == "b")
+
+ // Pagination tokens. To is optional, and its default value depends on the
+ // direction ("b" or "f").
+ var to *types.PaginationToken
+ wasToProvided := true
+ if s := req.URL.Query().Get("to"); len(s) > 0 {
+ to, err = types.NewPaginationTokenFromString(s)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
+ }
+ }
+ } else {
+ // If "to" isn't provided, it defaults to either the earliest stream
+ // position (if we're going backward) or to the latest one (if we're
+ // going forward).
+ to, err = setToDefault(req.Context(), db, backwardOrdering, roomID)
+ if err != nil {
+ return httputil.LogThenError(req, err)
+ }
+ wasToProvided = false
+ }
+
+ // Maximum number of events to return; defaults to 10.
+ limit := defaultMessagesLimit
+ if len(req.URL.Query().Get("limit")) > 0 {
+ limit, err = strconv.Atoi(req.URL.Query().Get("limit"))
+
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidArgumentValue("limit could not be parsed into an integer: " + err.Error()),
+ }
+ }
+ }
+ // TODO: Implement filtering (#587)
+
+ // Check the room ID's format.
+ if _, _, err = gomatrixserverlib.SplitID('!', roomID); err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.MissingArgument("Bad room ID: " + err.Error()),
+ }
+ }
+
+ mReq := messagesReq{
+ ctx: req.Context(),
+ db: db,
+ queryAPI: queryAPI,
+ federation: federation,
+ cfg: cfg,
+ roomID: roomID,
+ from: from,
+ to: to,
+ wasToProvided: wasToProvided,
+ limit: limit,
+ backwardOrdering: backwardOrdering,
+ }
+
+ clientEvents, start, end, err := mReq.retrieveEvents()
+ if err != nil {
+ return httputil.LogThenError(req, err)
+ }
+
+ // Respond with the events.
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: messagesResp{
+ Chunk: clientEvents,
+ Start: start.String(),
+ End: end.String(),
+ },
+ }
+}
+
+// retrieveEvents retrieve events from the local database for a request on
+// /messages. If there's not enough events to retrieve, it asks another
+// homeserver in the room for older events.
+// Returns an error if there was an issue talking to the database or with the
+// remote homeserver.
+func (r *messagesReq) retrieveEvents() (
+ clientEvents []gomatrixserverlib.ClientEvent, start,
+ end *types.PaginationToken, err error,
+) {
+ // Retrieve the events from the local database.
+ streamEvents, err := r.db.GetEventsInRange(
+ r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering,
+ )
+ if err != nil {
+ return
+ }
+
+ var events []gomatrixserverlib.Event
+
+ // There can be two reasons for streamEvents to be empty: either we've
+ // reached the oldest event in the room (or the most recent one, depending
+ // on the ordering), or we've reached a backward extremity.
+ if len(streamEvents) == 0 {
+ if events, err = r.handleEmptyEventsSlice(); err != nil {
+ return
+ }
+ } else {
+ if events, err = r.handleNonEmptyEventsSlice(streamEvents); err != nil {
+ return
+ }
+ }
+
+ // If we didn't get any event, we don't need to proceed any further.
+ if len(events) == 0 {
+ return []gomatrixserverlib.ClientEvent{}, r.from, r.to, nil
+ }
+
+ // Sort the events to ensure we send them in the right order. We currently
+ // do that based on the event's timestamp.
+ if r.backwardOrdering {
+ sort.SliceStable(events, func(i int, j int) bool {
+ // Backward ordering is antichronological (latest event to oldest
+ // one).
+ return sortEvents(&(events[j]), &(events[i]))
+ })
+ } else {
+ sort.SliceStable(events, func(i int, j int) bool {
+ // Forward ordering is chronological (oldest event to latest one).
+ return sortEvents(&(events[i]), &(events[j]))
+ })
+ }
+
+ // Convert all of the events into client events.
+ clientEvents = gomatrixserverlib.ToClientEvents(events, gomatrixserverlib.FormatAll)
+ // Get the position of the first and the last event in the room's topology.
+ // This position is currently determined by the event's depth, so we could
+ // also use it instead of retrieving from the database. However, if we ever
+ // change the way topological positions are defined (as depth isn't the most
+ // reliable way to define it), it would be easier and less troublesome to
+ // only have to change it in one place, i.e. the database.
+ startPos, err := r.db.EventPositionInTopology(
+ r.ctx, streamEvents[0].EventID(),
+ )
+ if err != nil {
+ return
+ }
+ endPos, err := r.db.EventPositionInTopology(
+ r.ctx, streamEvents[len(streamEvents)-1].EventID(),
+ )
+ if err != nil {
+ return
+ }
+ // Generate pagination tokens to send to the client using the positions
+ // retrieved previously.
+ start = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, startPos, 0,
+ )
+ end = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, endPos, 0,
+ )
+
+ if r.backwardOrdering {
+ // A stream/topological position is a cursor located between two events.
+ // While they are identified in the code by the event on their right (if
+ // we consider a left to right chronological order), tokens need to refer
+ // to them by the event on their left, therefore we need to decrement the
+ // end position we send in the response if we're going backward.
+ end.PDUPosition--
+ }
+
+ // The lowest token value is 1, therefore we need to manually set it to that
+ // value if we're below it.
+ if end.PDUPosition < types.StreamPosition(1) {
+ end.PDUPosition = types.StreamPosition(1)
+ }
+
+ return clientEvents, start, end, err
+}
+
+// handleEmptyEventsSlice handles the case where the initial request to the
+// database returned an empty slice of events. It does so by checking whether
+// the set is empty because we've reached a backward extremity, and if that is
+// the case, by retrieving as much events as requested by backfilling from
+// another homeserver.
+// Returns an error if there was an issue talking with the database or
+// backfilling.
+func (r *messagesReq) handleEmptyEventsSlice() (
+ events []gomatrixserverlib.Event, err error,
+) {
+ backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
+
+ // Check if we have backward extremities for this room.
+ if len(backwardExtremities) > 0 {
+ // If so, retrieve as much events as needed through backfilling.
+ events, err = r.backfill(backwardExtremities, r.limit)
+ if err != nil {
+ return
+ }
+ } else {
+ // If not, it means the slice was empty because we reached the room's
+ // creation, so return an empty slice.
+ events = []gomatrixserverlib.Event{}
+ }
+
+ return
+}
+
+// handleNonEmptyEventsSlice handles the case where the initial request to the
+// database returned a non-empty slice of events. It does so by checking whether
+// events are missing from the expected result, and retrieve missing events
+// through backfilling if needed.
+// Returns an error if there was an issue while backfilling.
+func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent) (
+ events []gomatrixserverlib.Event, err error,
+) {
+ // Check if we have enough events.
+ isSetLargeEnough := true
+ if len(streamEvents) < r.limit {
+ if r.backwardOrdering {
+ if r.wasToProvided {
+ // The condition in the SQL query is a strict "greater than" so
+ // we need to check against to-1.
+ streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition)
+ isSetLargeEnough = (r.to.PDUPosition-1 == streamPos)
+ }
+ } else {
+ streamPos := types.StreamPosition(streamEvents[0].StreamPosition)
+ isSetLargeEnough = (r.from.PDUPosition-1 == streamPos)
+ }
+ }
+
+ // Check if the slice contains a backward extremity.
+ backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
+ if err != nil {
+ return
+ }
+
+ // Backfill is needed if we've reached a backward extremity and need more
+ // events. It's only needed if the direction is backward.
+ if len(backwardExtremities) > 0 && !isSetLargeEnough && r.backwardOrdering {
+ var pdus []gomatrixserverlib.Event
+ // Only ask the remote server for enough events to reach the limit.
+ pdus, err = r.backfill(backwardExtremities, r.limit-len(streamEvents))
+ if err != nil {
+ return
+ }
+
+ // Append the PDUs to the list to send back to the client.
+ events = append(events, pdus...)
+ }
+
+ // Append the events ve previously retrieved locally.
+ events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...)
+
+ return
+}
+
+// containsBackwardExtremity checks if a slice of StreamEvent contains a
+// backward extremity. It does so by selecting the earliest event in the slice
+// and by checking the presence in the database of all of its parent events, and
+// considers the event itself a backward extremity if at least one of the parent
+// events doesn't exist in the database.
+// Returns an error if there was an issue with talking to the database.
+func (r *messagesReq) containsBackwardExtremity(events []types.StreamEvent) (bool, error) {
+ // Select the earliest retrieved event.
+ var ev *types.StreamEvent
+ if r.backwardOrdering {
+ ev = &(events[len(events)-1])
+ } else {
+ ev = &(events[0])
+ }
+ // Get the earliest retrieved event's parents.
+ prevIDs := ev.PrevEventIDs()
+ prevs, err := r.db.Events(r.ctx, prevIDs)
+ if err != nil {
+ return false, nil
+ }
+ // Check if we have all of the events we requested. If not, it means we've
+ // reached a backward extremity.
+ var eventInDB bool
+ var id string
+ // Iterate over the IDs we used in the request.
+ for _, id = range prevIDs {
+ eventInDB = false
+ // Iterate over the events we got in response.
+ for _, ev := range prevs {
+ if ev.EventID() == id {
+ eventInDB = true
+ }
+ }
+ // One occurrence of one the event's parents not being present in the
+ // database is enough to say that the event is a backward extremity.
+ if !eventInDB {
+ return true, nil
+ }
+ }
+
+ return false, nil
+}
+
+// backfill performs a backfill request over the federation on another
+// homeserver in the room.
+// See: https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
+// It also stores the PDUs retrieved from the remote homeserver's response to
+// the database.
+// Returns with an empty string if the remote homeserver didn't return with any
+// event, or if there is no remote homeserver to contact.
+// Returns an error if there was an issue with retrieving the list of servers in
+// the room or sending the request.
+func (r *messagesReq) backfill(fromEventIDs []string, limit int) ([]gomatrixserverlib.Event, error) {
+ // Query the list of servers in the room when one of the backward extremities
+ // was sent.
+ var serversResponse api.QueryServersInRoomAtEventResponse
+ serversRequest := api.QueryServersInRoomAtEventRequest{
+ RoomID: r.roomID,
+ EventID: fromEventIDs[0],
+ }
+ if err := r.queryAPI.QueryServersInRoomAtEvent(r.ctx, &serversRequest, &serversResponse); err != nil {
+ return nil, err
+ }
+
+ // Use the first server from the response, except if that server is us.
+ // In that case, use the second one if the roomserver responded with
+ // enough servers. If not, use an empty string to prevent the backfill
+ // from happening as there's no server to direct the request towards.
+ // TODO: Be smarter at selecting the server to direct the request
+ // towards.
+ srvToBackfillFrom := serversResponse.Servers[0]
+ if srvToBackfillFrom == r.cfg.Matrix.ServerName {
+ if len(serversResponse.Servers) > 1 {
+ srvToBackfillFrom = serversResponse.Servers[1]
+ } else {
+ srvToBackfillFrom = gomatrixserverlib.ServerName("")
+ log.Warn("Not enough servers to backfill from")
+ }
+ }
+
+ pdus := make([]gomatrixserverlib.Event, 0)
+
+ // If the roomserver responded with at least one server that isn't us,
+ // send it a request for backfill.
+ if len(srvToBackfillFrom) > 0 {
+ txn, err := r.federation.Backfill(
+ r.ctx, srvToBackfillFrom, r.roomID, limit, fromEventIDs,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ pdus = txn.PDUs
+
+ // Store the events in the database, while marking them as unfit to show
+ // up in responses to sync requests.
+ for _, pdu := range pdus {
+ if _, err = r.db.WriteEvent(
+ r.ctx, &pdu, []gomatrixserverlib.Event{}, []string{}, []string{},
+ nil, true,
+ ); err != nil {
+ return nil, err
+ }
+ }
+ }
+
+ return pdus, nil
+}
+
+// setToDefault returns the default value for the "to" query parameter of a
+// request to /messages if not provided. It defaults to either the earliest
+// topological position (if we're going backward) or to the latest one (if we're
+// going forward).
+// Returns an error if there was an issue with retrieving the latest position
+// from the database
+func setToDefault(
+ ctx context.Context, db storage.Database, backwardOrdering bool,
+ roomID string,
+) (to *types.PaginationToken, err error) {
+ if backwardOrdering {
+ to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 1, 0)
+ } else {
+ var pos types.StreamPosition
+ pos, err = db.MaxTopologicalPosition(ctx, roomID)
+ if err != nil {
+ return
+ }
+
+ to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, 0)
+ }
+
+ return
+}
+
+// sortEvents is a function to give to sort.SliceStable, and compares the
+// timestamp of two Matrix events.
+// Returns true if the first event happened before the second one, false
+// otherwise.
+func sortEvents(e1 *gomatrixserverlib.Event, e2 *gomatrixserverlib.Event) bool {
+ t := e1.OriginServerTS().Time()
+ return e2.OriginServerTS().Time().After(t)
+}
diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go
index bd9389bd..8916565d 100644
--- a/syncapi/routing/routing.go
+++ b/syncapi/routing/routing.go
@@ -22,8 +22,11 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/common/config"
+ "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@@ -34,7 +37,12 @@ const pathPrefixR0 = "/_matrix/client/r0"
// Due to Setup being used to call many other functions, a gocyclo nolint is
// applied:
// nolint: gocyclo
-func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, deviceDB *devices.Database) {
+func Setup(
+ apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database,
+ deviceDB *devices.Database, federation *gomatrixserverlib.FederationClient,
+ queryAPI api.RoomserverQueryAPI,
+ cfg *config.Dendrite,
+) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
authData := auth.Data{
@@ -71,4 +79,12 @@ func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, d
}
return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], vars["stateKey"])
})).Methods(http.MethodGet, http.MethodOptions)
+
+ r0mux.Handle("/rooms/{roomID}/messages", common.MakeAuthAPI("room_messages", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
+ vars, err := common.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, queryAPI, cfg)
+ })).Methods(http.MethodGet, http.MethodOptions)
}
diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go
index 36ba88cd..94e6ac41 100644
--- a/syncapi/storage/postgres/account_data_table.go
+++ b/syncapi/storage/postgres/account_data_table.go
@@ -21,6 +21,7 @@ import (
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrix"
)
@@ -89,7 +90,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData(
ctx context.Context,
userID, roomID, dataType string,
-) (pos int64, err error) {
+) (pos types.StreamPosition, err error) {
err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos)
return
}
@@ -97,7 +98,7 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountDataInRange(
ctx context.Context,
userID string,
- oldPos, newPos int64,
+ oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrix.FilterPart,
) (data map[string][]string, err error) {
data = make(map[string][]string)
diff --git a/syncapi/storage/postgres/backward_extremities_table.go b/syncapi/storage/postgres/backward_extremities_table.go
new file mode 100644
index 00000000..476d26fa
--- /dev/null
+++ b/syncapi/storage/postgres/backward_extremities_table.go
@@ -0,0 +1,118 @@
+// Copyright 2018 New Vector Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+)
+
+const backwardExtremitiesSchema = `
+-- Stores output room events received from the roomserver.
+CREATE TABLE IF NOT EXISTS syncapi_backward_extremities (
+ -- The 'room_id' key for the event.
+ room_id TEXT NOT NULL,
+ -- The event ID for the event.
+ event_id TEXT NOT NULL,
+
+ PRIMARY KEY(room_id, event_id)
+);
+`
+
+const insertBackwardExtremitySQL = "" +
+ "INSERT INTO syncapi_backward_extremities (room_id, event_id)" +
+ " VALUES ($1, $2)"
+
+const selectBackwardExtremitiesForRoomSQL = "" +
+ "SELECT event_id FROM syncapi_backward_extremities WHERE room_id = $1"
+
+const isBackwardExtremitySQL = "" +
+ "SELECT EXISTS (" +
+ " SELECT TRUE FROM syncapi_backward_extremities" +
+ " WHERE room_id = $1 AND event_id = $2" +
+ ")"
+
+const deleteBackwardExtremitySQL = "" +
+ "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND event_id = $2"
+
+type backwardExtremitiesStatements struct {
+ insertBackwardExtremityStmt *sql.Stmt
+ selectBackwardExtremitiesForRoomStmt *sql.Stmt
+ isBackwardExtremityStmt *sql.Stmt
+ deleteBackwardExtremityStmt *sql.Stmt
+}
+
+func (s *backwardExtremitiesStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(backwardExtremitiesSchema)
+ if err != nil {
+ return
+ }
+ if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil {
+ return
+ }
+ if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil {
+ return
+ }
+ if s.isBackwardExtremityStmt, err = db.Prepare(isBackwardExtremitySQL); err != nil {
+ return
+ }
+ if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *backwardExtremitiesStatements) insertsBackwardExtremity(
+ ctx context.Context, roomID, eventID string,
+) (err error) {
+ _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID)
+ return
+}
+
+func (s *backwardExtremitiesStatements) selectBackwardExtremitiesForRoom(
+ ctx context.Context, roomID string,
+) (eventIDs []string, err error) {
+ eventIDs = make([]string, 0)
+
+ rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
+ if err != nil {
+ return
+ }
+
+ for rows.Next() {
+ var eID string
+ if err = rows.Scan(&eID); err != nil {
+ return
+ }
+
+ eventIDs = append(eventIDs, eID)
+ }
+
+ return
+}
+
+func (s *backwardExtremitiesStatements) isBackwardExtremity(
+ ctx context.Context, roomID, eventID string,
+) (isBE bool, err error) {
+ err = s.isBackwardExtremityStmt.QueryRowContext(ctx, roomID, eventID).Scan(&isBE)
+ return
+}
+
+func (s *backwardExtremitiesStatements) deleteBackwardExtremity(
+ ctx context.Context, roomID, eventID string,
+) (err error) {
+ _, err = s.insertBackwardExtremityStmt.ExecContext(ctx, roomID, eventID)
+ return
+}
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go
index 8b208043..816cbb44 100644
--- a/syncapi/storage/postgres/current_room_state_table.go
+++ b/syncapi/storage/postgres/current_room_state_table.go
@@ -22,6 +22,7 @@ import (
"github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -87,10 +88,10 @@ const selectStateEventSQL = "" +
const selectEventsWithEventIDsSQL = "" +
// TODO: The session_id and transaction_id blanks are here because otherwise
- // the rowsToStreamEvents expects there to be exactly four columns. We need to
+ // the rowsToStreamEvents expects there to be exactly five columns. We need to
// figure out if these really need to be in the DB, and if so, we need a
// better permanent fix for this. - neilalexander, 2 Jan 2020
- "SELECT added_at, event_json, 0 AS session_id, '' AS transaction_id" +
+ "SELECT added_at, event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" +
" FROM syncapi_current_room_state WHERE event_id = ANY($1)"
type currentRoomStateStatements struct {
@@ -213,7 +214,7 @@ func (s *currentRoomStateStatements) deleteRoomStateByEventID(
func (s *currentRoomStateStatements) upsertRoomState(
ctx context.Context, txn *sql.Tx,
- event gomatrixserverlib.Event, membership *string, addedAt int64,
+ event gomatrixserverlib.Event, membership *string, addedAt types.StreamPosition,
) error {
// Parse content as JSON and search for an "url" key
containsURL := false
@@ -242,7 +243,7 @@ func (s *currentRoomStateStatements) upsertRoomState(
func (s *currentRoomStateStatements) selectEventsWithEventIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
-) ([]streamEvent, error) {
+) ([]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectEventsWithEventIDsStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go
index ced4bfc4..ca4bbeb5 100644
--- a/syncapi/storage/postgres/invites_table.go
+++ b/syncapi/storage/postgres/invites_table.go
@@ -20,6 +20,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/common"
+ "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -86,7 +87,7 @@ func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) {
func (s *inviteEventsStatements) insertInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event,
-) (streamPos int64, err error) {
+) (streamPos types.StreamPosition, err error) {
err = s.insertInviteEventStmt.QueryRowContext(
ctx,
inviteEvent.RoomID(),
@@ -107,7 +108,7 @@ func (s *inviteEventsStatements) deleteInviteEvent(
// selectInviteEventsInRange returns a map of room ID to invite event for the
// active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) selectInviteEventsInRange(
- ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos int64,
+ ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition,
) (map[string]gomatrixserverlib.Event, error) {
stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos)
diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go
index ca271593..be302d73 100644
--- a/syncapi/storage/postgres/output_room_events_table.go
+++ b/syncapi/storage/postgres/output_room_events_table.go
@@ -22,6 +22,7 @@ import (
"sort"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrix"
"github.com/lib/pq"
@@ -36,28 +37,35 @@ CREATE SEQUENCE IF NOT EXISTS syncapi_stream_id;
-- Stores output room events received from the roomserver.
CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
- -- An incrementing ID which denotes the position in the log that this event resides at.
- -- NB: 'serial' makes no guarantees to increment by 1 every time, only that it increments.
- -- This isn't a problem for us since we just want to order by this field.
- id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'),
- -- The event ID for the event
- event_id TEXT NOT NULL,
- -- The 'room_id' key for the event.
- room_id TEXT NOT NULL,
- -- The JSON for the event. Stored as TEXT because this should be valid UTF-8.
- event_json TEXT NOT NULL,
- -- The event type e.g 'm.room.member'.
- type TEXT NOT NULL,
- -- The 'sender' property of the event.
- sender TEXT NOT NULL,
- -- true if the event content contains a url key.
- contains_url BOOL NOT NULL,
- -- A list of event IDs which represent a delta of added/removed room state. This can be NULL
- -- if there is no delta.
- add_state_ids TEXT[],
- remove_state_ids TEXT[],
- session_id BIGINT, -- The client session that sent the event, if any
- transaction_id TEXT -- The transaction id used to send the event, if any
+ -- An incrementing ID which denotes the position in the log that this event resides at.
+ -- NB: 'serial' makes no guarantees to increment by 1 every time, only that it increments.
+ -- This isn't a problem for us since we just want to order by this field.
+ id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'),
+ -- The event ID for the event
+ event_id TEXT NOT NULL,
+ -- The 'room_id' key for the event.
+ room_id TEXT NOT NULL,
+ -- The JSON for the event. Stored as TEXT because this should be valid UTF-8.
+ event_json TEXT NOT NULL,
+ -- The event type e.g 'm.room.member'.
+ type TEXT NOT NULL,
+ -- The 'sender' property of the event.
+ sender TEXT NOT NULL,
+ -- true if the event content contains a url key.
+ contains_url BOOL NOT NULL,
+ -- A list of event IDs which represent a delta of added/removed room state. This can be NULL
+ -- if there is no delta.
+ add_state_ids TEXT[],
+ remove_state_ids TEXT[],
+ -- The client session that sent the event, if any
+ session_id BIGINT,
+ -- The transaction id used to send the event, if any
+ transaction_id TEXT,
+ -- Should the event be excluded from responses to /sync requests. Useful for
+ -- events retrieved through backfilling that have a position in the stream
+ -- that relates to the moment these were retrieved rather than the moment these
+ -- were emitted.
+ exclude_from_sync BOOL DEFAULT FALSE
);
-- for event selection
CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_events(event_id);
@@ -65,23 +73,33 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev
const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" +
- "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" +
- ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id"
+ "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync" +
+ ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id"
const selectEventsSQL = "" +
- "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectRecentEventsSQL = "" +
- "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4"
+const selectRecentEventsForSyncSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
+ " ORDER BY id DESC LIMIT $4"
+
+const selectEarlyEventsSQL = "" +
+ "SELECT id, event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
+ " WHERE room_id = $1 AND id > $2 AND id <= $3" +
+ " ORDER BY id ASC LIMIT $4"
+
const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
const selectStateInRangeSQL = "" +
- "SELECT id, event_json, add_state_ids, remove_state_ids" +
+ "SELECT id, event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
" AND ( $3::text[] IS NULL OR sender = ANY($3) )" +
@@ -93,11 +111,13 @@ const selectStateInRangeSQL = "" +
" LIMIT $8"
type outputRoomEventsStatements struct {
- insertEventStmt *sql.Stmt
- selectEventsStmt *sql.Stmt
- selectMaxEventIDStmt *sql.Stmt
- selectRecentEventsStmt *sql.Stmt
- selectStateInRangeStmt *sql.Stmt
+ insertEventStmt *sql.Stmt
+ selectEventsStmt *sql.Stmt
+ selectMaxEventIDStmt *sql.Stmt
+ selectRecentEventsStmt *sql.Stmt
+ selectRecentEventsForSyncStmt *sql.Stmt
+ selectEarlyEventsStmt *sql.Stmt
+ selectStateInRangeStmt *sql.Stmt
}
func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
@@ -117,6 +137,12 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
return
}
+ if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil {
+ return
+ }
+ if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil {
+ return
+ }
if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
return
}
@@ -127,9 +153,9 @@ func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) {
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) selectStateInRange(
- ctx context.Context, txn *sql.Tx, oldPos, newPos int64,
+ ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition,
stateFilterPart *gomatrix.FilterPart,
-) (map[string]map[string]bool, map[string]streamEvent, error) {
+) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
rows, err := stmt.QueryContext(
@@ -149,19 +175,20 @@ func (s *outputRoomEventsStatements) selectStateInRange(
// - For each room ID, build up an array of event IDs which represents cumulative adds/removes
// For each room, map cumulative event IDs to events and return. This may need to a batch SELECT based on event ID
// if they aren't in the event ID cache. We don't handle state deletion yet.
- eventIDToEvent := make(map[string]streamEvent)
+ eventIDToEvent := make(map[string]types.StreamEvent)
// RoomID => A set (map[string]bool) of state event IDs which are between the two positions
stateNeeded := make(map[string]map[string]bool)
for rows.Next() {
var (
- streamPos int64
- eventBytes []byte
- addIDs pq.StringArray
- delIDs pq.StringArray
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ addIDs pq.StringArray
+ delIDs pq.StringArray
)
- if err := rows.Scan(&streamPos, &eventBytes, &addIDs, &delIDs); err != nil {
+ if err := rows.Scan(&streamPos, &eventBytes, &excludeFromSync, &addIDs, &delIDs); err != nil {
return nil, nil, err
}
// Sanity check for deleted state and whine if we see it. We don't need to do anything
@@ -192,9 +219,10 @@ func (s *outputRoomEventsStatements) selectStateInRange(
}
stateNeeded[ev.RoomID()] = needSet
- eventIDToEvent[ev.EventID()] = streamEvent{
- Event: ev,
- streamPosition: streamPos,
+ eventIDToEvent[ev.EventID()] = types.StreamEvent{
+ Event: ev,
+ StreamPosition: streamPos,
+ ExcludeFromSync: excludeFromSync,
}
}
@@ -221,8 +249,8 @@ func (s *outputRoomEventsStatements) selectMaxEventID(
func (s *outputRoomEventsStatements) insertEvent(
ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.Event, addState, removeState []string,
- transactionID *api.TransactionID,
-) (streamPos int64, err error) {
+ transactionID *api.TransactionID, excludeFromSync bool,
+) (streamPos types.StreamPosition, err error) {
var txnID *string
var sessionID *int64
if transactionID != nil {
@@ -251,16 +279,53 @@ func (s *outputRoomEventsStatements) insertEvent(
pq.StringArray(removeState),
sessionID,
txnID,
+ excludeFromSync,
).Scan(&streamPos)
return
}
-// RecentEventsInRoom returns the most recent events in the given room, up to a maximum of 'limit'.
+// selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'.
+// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude
+// from sync.
func (s *outputRoomEventsStatements) selectRecentEvents(
ctx context.Context, txn *sql.Tx,
- roomID string, fromPos, toPos int64, limit int,
-) ([]streamEvent, error) {
- stmt := common.TxStmt(txn, s.selectRecentEventsStmt)
+ roomID string, fromPos, toPos types.StreamPosition, limit int,
+ chronologicalOrder bool, onlySyncEvents bool,
+) ([]types.StreamEvent, error) {
+ var stmt *sql.Stmt
+ if onlySyncEvents {
+ stmt = common.TxStmt(txn, s.selectRecentEventsForSyncStmt)
+ } else {
+ stmt = common.TxStmt(txn, s.selectRecentEventsStmt)
+ }
+
+ rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close() // nolint: errcheck
+ events, err := rowsToStreamEvents(rows)
+ if err != nil {
+ return nil, err
+ }
+ if chronologicalOrder {
+ // The events need to be returned from oldest to latest, which isn't
+ // necessary the way the SQL query returns them, so a sort is necessary to
+ // ensure the events are in the right order in the slice.
+ sort.SliceStable(events, func(i int, j int) bool {
+ return events[i].StreamPosition < events[j].StreamPosition
+ })
+ }
+ return events, nil
+}
+
+// selectEarlyEvents returns the earliest events in the given room, starting
+// from a given position, up to a maximum of 'limit'.
+func (s *outputRoomEventsStatements) selectEarlyEvents(
+ ctx context.Context, txn *sql.Tx,
+ roomID string, fromPos, toPos types.StreamPosition, limit int,
+) ([]types.StreamEvent, error) {
+ stmt := common.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
if err != nil {
return nil, err
@@ -274,16 +339,16 @@ func (s *outputRoomEventsStatements) selectRecentEvents(
// necessarily the way the SQL query returns them, so a sort is necessary to
// ensure the events are in the right order in the slice.
sort.SliceStable(events, func(i int, j int) bool {
- return events[i].streamPosition < events[j].streamPosition
+ return events[i].StreamPosition < events[j].StreamPosition
})
return events, nil
}
-// Events returns the events for the given event IDs. Returns an error if any one of the event IDs given are missing
-// from the database.
+// selectEvents returns the events for the given event IDs. If an event is
+// missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) selectEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
-) ([]streamEvent, error) {
+) ([]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectEventsStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
@@ -293,17 +358,18 @@ func (s *outputRoomEventsStatements) selectEvents(
return rowsToStreamEvents(rows)
}
-func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
- var result []streamEvent
+func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
+ var result []types.StreamEvent
for rows.Next() {
var (
- streamPos int64
- eventBytes []byte
- sessionID *int64
- txnID *string
- transactionID *api.TransactionID
+ streamPos types.StreamPosition
+ eventBytes []byte
+ excludeFromSync bool
+ sessionID *int64
+ txnID *string
+ transactionID *api.TransactionID
)
- if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil {
+ if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID); err != nil {
return nil, err
}
// TODO: Handle redacted events
@@ -319,10 +385,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
}
}
- result = append(result, streamEvent{
- Event: ev,
- streamPosition: streamPos,
- transactionID: transactionID,
+ result = append(result, types.StreamEvent{
+ Event: ev,
+ StreamPosition: streamPos,
+ TransactionID: transactionID,
+ ExcludeFromSync: excludeFromSync,
})
}
return result, nil
diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go
new file mode 100644
index 00000000..4a50b9a0
--- /dev/null
+++ b/syncapi/storage/postgres/output_room_events_topology_table.go
@@ -0,0 +1,188 @@
+// Copyright 2018 New Vector Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+const outputRoomEventsTopologySchema = `
+-- Stores output room events received from the roomserver.
+CREATE TABLE IF NOT EXISTS syncapi_output_room_events_topology (
+ -- The event ID for the event.
+ event_id TEXT PRIMARY KEY,
+ -- The place of the event in the room's topology. This can usually be determined
+ -- from the event's depth.
+ topological_position BIGINT NOT NULL,
+ -- The 'room_id' key for the event.
+ room_id TEXT NOT NULL
+);
+-- The topological order will be used in events selection and ordering
+CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_topological_position_idx ON syncapi_output_room_events_topology(topological_position, room_id);
+`
+
+const insertEventInTopologySQL = "" +
+ "INSERT INTO syncapi_output_room_events_topology (event_id, topological_position, room_id)" +
+ " VALUES ($1, $2, $3)" +
+ " ON CONFLICT DO NOTHING"
+
+const selectEventIDsInRangeASCSQL = "" +
+ "SELECT event_id FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" +
+ " ORDER BY topological_position ASC LIMIT $4"
+
+const selectEventIDsInRangeDESCSQL = "" +
+ "SELECT event_id FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1 AND topological_position > $2 AND topological_position <= $3" +
+ " ORDER BY topological_position DESC LIMIT $4"
+
+const selectPositionInTopologySQL = "" +
+ "SELECT topological_position FROM syncapi_output_room_events_topology" +
+ " WHERE event_id = $1"
+
+const selectMaxPositionInTopologySQL = "" +
+ "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1"
+
+const selectEventIDsFromPositionSQL = "" +
+ "SELECT event_id FROM syncapi_output_room_events_topology" +
+ " WHERE room_id = $1 AND topological_position = $2"
+
+type outputRoomEventsTopologyStatements struct {
+ insertEventInTopologyStmt *sql.Stmt
+ selectEventIDsInRangeASCStmt *sql.Stmt
+ selectEventIDsInRangeDESCStmt *sql.Stmt
+ selectPositionInTopologyStmt *sql.Stmt
+ selectMaxPositionInTopologyStmt *sql.Stmt
+ selectEventIDsFromPositionStmt *sql.Stmt
+}
+
+func (s *outputRoomEventsTopologyStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(outputRoomEventsTopologySchema)
+ if err != nil {
+ return
+ }
+ if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil {
+ return
+ }
+ if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil {
+ return
+ }
+ if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil {
+ return
+ }
+ if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
+ return
+ }
+ if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
+ return
+ }
+ if s.selectEventIDsFromPositionStmt, err = db.Prepare(selectEventIDsFromPositionSQL); err != nil {
+ return
+ }
+ return
+}
+
+// insertEventInTopology inserts the given event in the room's topology, based
+// on the event's depth.
+func (s *outputRoomEventsTopologyStatements) insertEventInTopology(
+ ctx context.Context, event *gomatrixserverlib.Event,
+) (err error) {
+ _, err = s.insertEventInTopologyStmt.ExecContext(
+ ctx, event.EventID(), event.Depth(), event.RoomID(),
+ )
+ return
+}
+
+// selectEventIDsInRange selects the IDs of events which positions are within a
+// given range in a given room's topological order.
+// Returns an empty slice if no events match the given range.
+func (s *outputRoomEventsTopologyStatements) selectEventIDsInRange(
+ ctx context.Context, roomID string, fromPos, toPos types.StreamPosition,
+ limit int, chronologicalOrder bool,
+) (eventIDs []string, err error) {
+ // Decide on the selection's order according to whether chronological order
+ // is requested or not.
+ var stmt *sql.Stmt
+ if chronologicalOrder {
+ stmt = s.selectEventIDsInRangeASCStmt
+ } else {
+ stmt = s.selectEventIDsInRangeDESCStmt
+ }
+
+ // Query the event IDs.
+ rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit)
+ if err == sql.ErrNoRows {
+ // If no event matched the request, return an empty slice.
+ return []string{}, nil
+ } else if err != nil {
+ return
+ }
+
+ // Return the IDs.
+ var eventID string
+ for rows.Next() {
+ if err = rows.Scan(&eventID); err != nil {
+ return
+ }
+ eventIDs = append(eventIDs, eventID)
+ }
+
+ return
+}
+
+// selectPositionInTopology returns the position of a given event in the
+// topology of the room it belongs to.
+func (s *outputRoomEventsTopologyStatements) selectPositionInTopology(
+ ctx context.Context, eventID string,
+) (pos types.StreamPosition, err error) {
+ err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos)
+ return
+}
+
+func (s *outputRoomEventsTopologyStatements) selectMaxPositionInTopology(
+ ctx context.Context, roomID string,
+) (pos types.StreamPosition, err error) {
+ err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos)
+ return
+}
+
+// selectEventIDsFromPosition returns the IDs of all events that have a given
+// position in the topology of a given room.
+func (s *outputRoomEventsTopologyStatements) selectEventIDsFromPosition(
+ ctx context.Context, roomID string, pos types.StreamPosition,
+) (eventIDs []string, err error) {
+ // Query the event IDs.
+ rows, err := s.selectEventIDsFromPositionStmt.QueryContext(ctx, roomID, pos)
+ if err == sql.ErrNoRows {
+ // If no event matched the request, return an empty slice.
+ return []string{}, nil
+ } else if err != nil {
+ return
+ }
+ // Return the IDs.
+ var eventID string
+ for rows.Next() {
+ if err = rows.Scan(&eventID); err != nil {
+ return
+ }
+ eventIDs = append(eventIDs, eventID)
+ }
+ return
+}
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go
index 3a62d136..621aec95 100644
--- a/syncapi/storage/postgres/syncserver.go
+++ b/syncapi/storage/postgres/syncserver.go
@@ -20,7 +20,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
- "strconv"
"time"
"github.com/sirupsen/logrus"
@@ -43,29 +42,24 @@ type stateDelta struct {
membership string
// The PDU stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta.
- membershipPos int64
+ membershipPos types.StreamPosition
}
-// Same as gomatrixserverlib.Event but also has the PDU stream position for this event.
-type streamEvent struct {
- gomatrixserverlib.Event
- streamPosition int64
- transactionID *api.TransactionID
-}
-
-// SyncServerDatabase represents a sync server datasource which manages
+// SyncServerDatasource represents a sync server datasource which manages
// both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
db *sql.DB
common.PartitionOffsetStatements
- accountData accountDataStatements
- events outputRoomEventsStatements
- roomstate currentRoomStateStatements
- invites inviteEventsStatements
- typingCache *cache.TypingCache
+ accountData accountDataStatements
+ events outputRoomEventsStatements
+ roomstate currentRoomStateStatements
+ invites inviteEventsStatements
+ typingCache *cache.TypingCache
+ topology outputRoomEventsTopologyStatements
+ backwardExtremities backwardExtremitiesStatements
}
-// NewSyncServerDatabase creates a new sync server database
+// NewSyncServerDatasource creates a new sync server database
func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) {
var d SyncServerDatasource
var err error
@@ -87,6 +81,12 @@ func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, er
if err := d.invites.prepare(d.db); err != nil {
return nil, err
}
+ if err := d.topology.prepare(d.db); err != nil {
+ return nil, err
+ }
+ if err := d.backwardExtremities.prepare(d.db); err != nil {
+ return nil, err
+ }
d.typingCache = cache.NewTypingCache()
return &d, nil
}
@@ -109,7 +109,46 @@ func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([
// We don't include a device here as we only include transaction IDs in
// incremental syncs.
- return streamEventsToEvents(nil, streamEvents), nil
+ return d.StreamEventsToEvents(nil, streamEvents), nil
+}
+
+func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, ev *gomatrixserverlib.Event) error {
+ // If the event is already known as a backward extremity, don't consider
+ // it as such anymore now that we have it.
+ isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, ev.RoomID(), ev.EventID())
+ if err != nil {
+ return err
+ }
+ if isBackwardExtremity {
+ if err = d.backwardExtremities.deleteBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil {
+ return err
+ }
+ }
+
+ // Check if we have all of the event's previous events. If an event is
+ // missing, add it to the room's backward extremities.
+ prevEvents, err := d.events.selectEvents(ctx, nil, ev.PrevEventIDs())
+ if err != nil {
+ return err
+ }
+ var found bool
+ for _, eID := range ev.PrevEventIDs() {
+ found = false
+ for _, prevEv := range prevEvents {
+ if eID == prevEv.EventID() {
+ found = true
+ }
+ }
+
+ // If the event is missing, consider it a backward extremity.
+ if !found {
+ if err = d.backwardExtremities.insertsBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
}
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
@@ -120,16 +159,26 @@ func (d *SyncServerDatasource) WriteEvent(
ev *gomatrixserverlib.Event,
addStateEvents []gomatrixserverlib.Event,
addStateEventIDs, removeStateEventIDs []string,
- transactionID *api.TransactionID,
-) (pduPosition int64, returnErr error) {
+ transactionID *api.TransactionID, excludeFromSync bool,
+) (pduPosition types.StreamPosition, returnErr error) {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
- pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID)
+ pos, err := d.events.insertEvent(
+ ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
+ )
if err != nil {
return err
}
pduPosition = pos
+ if err = d.topology.insertEventInTopology(ctx, ev); err != nil {
+ return err
+ }
+
+ if err = d.handleBackwardExtremities(ctx, ev); err != nil {
+ return err
+ }
+
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
// Nothing to do, the event may have just been a message event.
return nil
@@ -137,14 +186,15 @@ func (d *SyncServerDatasource) WriteEvent(
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition)
})
- return
+
+ return pduPosition, returnErr
}
func (d *SyncServerDatasource) updateRoomState(
ctx context.Context, txn *sql.Tx,
removedEventIDs []string,
addedEvents []gomatrixserverlib.Event,
- pduPosition int64,
+ pduPosition types.StreamPosition,
) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs {
@@ -196,14 +246,141 @@ func (d *SyncServerDatasource) GetStateEventsForRoom(
return
}
+// GetEventsInRange retrieves all of the events on a given ordering using the
+// given extremities and limit.
+func (d *SyncServerDatasource) GetEventsInRange(
+ ctx context.Context,
+ from, to *types.PaginationToken,
+ roomID string, limit int,
+ backwardOrdering bool,
+) (events []types.StreamEvent, err error) {
+ // If the pagination token's type is types.PaginationTokenTypeTopology, the
+ // events must be retrieved from the rooms' topology table rather than the
+ // table contaning the syncapi server's whole stream of events.
+ if from.Type == types.PaginationTokenTypeTopology {
+ // Determine the backward and forward limit, i.e. the upper and lower
+ // limits to the selection in the room's topology, from the direction.
+ var backwardLimit, forwardLimit types.StreamPosition
+ if backwardOrdering {
+ // Backward ordering is antichronological (latest event to oldest
+ // one).
+ backwardLimit = to.PDUPosition
+ forwardLimit = from.PDUPosition
+ } else {
+ // Forward ordering is chronological (oldest event to latest one).
+ backwardLimit = from.PDUPosition
+ forwardLimit = to.PDUPosition
+ }
+
+ // Select the event IDs from the defined range.
+ var eIDs []string
+ eIDs, err = d.topology.selectEventIDsInRange(
+ ctx, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering,
+ )
+ if err != nil {
+ return
+ }
+
+ // Retrieve the events' contents using their IDs.
+ events, err = d.events.selectEvents(ctx, nil, eIDs)
+ return
+ }
+
+ // If the pagination token's type is types.PaginationTokenTypeStream, the
+ // events must be retrieved from the table contaning the syncapi server's
+ // whole stream of events.
+
+ if backwardOrdering {
+ // When using backward ordering, we want the most recent events first.
+ if events, err = d.events.selectRecentEvents(
+ ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false,
+ ); err != nil {
+ return
+ }
+ } else {
+ // When using forward ordering, we want the least recent events first.
+ if events, err = d.events.selectEarlyEvents(
+ ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit,
+ ); err != nil {
+ return
+ }
+ }
+
+ return
+}
+
// SyncPosition returns the latest positions for syncing.
-func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) {
+func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) {
return d.syncPositionTx(ctx, nil)
}
+// BackwardExtremitiesForRoom returns the event IDs of all of the backward
+// extremities we know of for a given room.
+func (d *SyncServerDatasource) BackwardExtremitiesForRoom(
+ ctx context.Context, roomID string,
+) (backwardExtremities []string, err error) {
+ return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, roomID)
+}
+
+// MaxTopologicalPosition returns the highest topological position for a given
+// room.
+func (d *SyncServerDatasource) MaxTopologicalPosition(
+ ctx context.Context, roomID string,
+) (types.StreamPosition, error) {
+ return d.topology.selectMaxPositionInTopology(ctx, roomID)
+}
+
+// EventsAtTopologicalPosition returns all of the events matching a given
+// position in the topology of a given room.
+func (d *SyncServerDatasource) EventsAtTopologicalPosition(
+ ctx context.Context, roomID string, pos types.StreamPosition,
+) ([]types.StreamEvent, error) {
+ eIDs, err := d.topology.selectEventIDsFromPosition(ctx, roomID, pos)
+ if err != nil {
+ return nil, err
+ }
+
+ return d.events.selectEvents(ctx, nil, eIDs)
+}
+
+func (d *SyncServerDatasource) EventPositionInTopology(
+ ctx context.Context, eventID string,
+) (types.StreamPosition, error) {
+ return d.topology.selectPositionInTopology(ctx, eventID)
+}
+
+// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
+func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) {
+ return d.syncStreamPositionTx(ctx, nil)
+}
+
+func (d *SyncServerDatasource) syncStreamPositionTx(
+ ctx context.Context, txn *sql.Tx,
+) (types.StreamPosition, error) {
+ maxID, err := d.events.selectMaxEventID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ if maxAccountDataID > maxID {
+ maxID = maxAccountDataID
+ }
+ maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ if maxInviteID > maxID {
+ maxID = maxInviteID
+ }
+ return types.StreamPosition(maxID), nil
+}
+
func (d *SyncServerDatasource) syncPositionTx(
ctx context.Context, txn *sql.Tx,
-) (sp types.SyncPosition, err error) {
+) (sp types.PaginationToken, err error) {
maxEventID, err := d.events.selectMaxEventID(ctx, txn)
if err != nil {
@@ -223,10 +400,8 @@ func (d *SyncServerDatasource) syncPositionTx(
if maxInviteID > maxEventID {
maxEventID = maxInviteID
}
- sp.PDUPosition = maxEventID
-
- sp.TypingPosition = d.typingCache.GetLatestSyncPosition()
-
+ sp.PDUPosition = types.StreamPosition(maxEventID)
+ sp.EDUTypingPosition = types.StreamPosition(d.typingCache.GetLatestSyncPosition())
return
}
@@ -235,7 +410,7 @@ func (d *SyncServerDatasource) syncPositionTx(
func (d *SyncServerDatasource) addPDUDeltaToResponse(
ctx context.Context,
device authtypes.Device,
- fromPos, toPos int64,
+ fromPos, toPos types.StreamPosition,
numRecentEventsPerRoom int,
wantFullState bool,
res *types.Response,
@@ -287,7 +462,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
// addTypingDeltaToResponse adds all typing notifications to a sync response
// since the specified position.
func (d *SyncServerDatasource) addTypingDeltaToResponse(
- since int64,
+ since types.PaginationToken,
joinedRoomIDs []string,
res *types.Response,
) error {
@@ -296,7 +471,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
var err error
for _, roomID := range joinedRoomIDs {
if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter(
- roomID, since,
+ roomID, int64(since.EDUTypingPosition),
); updated {
ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping,
@@ -321,14 +496,14 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
// the positions of that type are not equal in fromPos and toPos.
func (d *SyncServerDatasource) addEDUDeltaToResponse(
- fromPos, toPos types.SyncPosition,
+ fromPos, toPos types.PaginationToken,
joinedRoomIDs []string,
res *types.Response,
) (err error) {
- if fromPos.TypingPosition != toPos.TypingPosition {
+ if fromPos.EDUTypingPosition != toPos.EDUTypingPosition {
err = d.addTypingDeltaToResponse(
- fromPos.TypingPosition, joinedRoomIDs, res,
+ fromPos, joinedRoomIDs, res,
)
}
@@ -343,7 +518,7 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse(
func (d *SyncServerDatasource) IncrementalSync(
ctx context.Context,
device authtypes.Device,
- fromPos, toPos types.SyncPosition,
+ fromPos, toPos types.PaginationToken,
numRecentEventsPerRoom int,
wantFullState bool,
) (*types.Response, error) {
@@ -383,7 +558,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
numRecentEventsPerRoom int,
) (
res *types.Response,
- toPos types.SyncPosition,
+ toPos types.PaginationToken,
joinedRoomIDs []string,
err error,
) {
@@ -423,27 +598,37 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
}
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
- var recentStreamEvents []streamEvent
+ var recentStreamEvents []types.StreamEvent
recentStreamEvents, err = d.events.selectRecentEvents(
- ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom,
+ ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition,
+ numRecentEventsPerRoom, true, true,
+ //ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom,
)
if err != nil {
return
}
+ // Retrieve the backward topology position, i.e. the position of the
+ // oldest event in the room's topology.
+ var backwardTopologyPos types.StreamPosition
+ backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID())
+ if err != nil {
+ return nil, types.PaginationToken{}, []string{}, err
+ }
+ if backwardTopologyPos-1 <= 0 {
+ backwardTopologyPos = types.StreamPosition(1)
+ } else {
+ backwardTopologyPos = backwardTopologyPos - 1
+ }
+
// We don't include a device here as we don't need to send down
// transaction IDs for complete syncs
- recentEvents := streamEventsToEvents(nil, recentStreamEvents)
-
+ recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse()
- if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 {
- // Use the short form of batch token for prev_batch
- jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
- } else {
- // Use the short form of batch token for prev_batch
- jr.Timeline.PrevBatch = "1"
- }
+ jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true
jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
@@ -471,7 +656,7 @@ func (d *SyncServerDatasource) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse(
- types.SyncPosition{}, toPos, joinedRoomIDs, res,
+ types.PaginationToken{}, toPos, joinedRoomIDs, res,
)
if err != nil {
return nil, err
@@ -496,7 +681,7 @@ var txReadOnlySnapshot = sql.TxOptions{
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
func (d *SyncServerDatasource) GetAccountDataInRange(
- ctx context.Context, userID string, oldPos, newPos int64,
+ ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrix.FilterPart,
) (map[string][]string, error) {
return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
@@ -510,7 +695,7 @@ func (d *SyncServerDatasource) GetAccountDataInRange(
// Returns an error if there was an issue with the upsert
func (d *SyncServerDatasource) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string,
-) (int64, error) {
+) (types.StreamPosition, error) {
return d.accountData.insertAccountData(ctx, userID, roomID, dataType)
}
@@ -519,7 +704,7 @@ func (d *SyncServerDatasource) UpsertAccountData(
// Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event,
-) (int64, error) {
+) (types.StreamPosition, error) {
return d.invites.insertInviteEvent(ctx, inviteEvent)
}
@@ -542,26 +727,26 @@ func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallback
// Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) AddTypingUser(
userID, roomID string, expireTime *time.Time,
-) int64 {
- return d.typingCache.AddTypingUser(userID, roomID, expireTime)
+) types.StreamPosition {
+ return types.StreamPosition(d.typingCache.AddTypingUser(userID, roomID, expireTime))
}
// RemoveTypingUser removes a typing user from the typing cache.
// Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) RemoveTypingUser(
userID, roomID string,
-) int64 {
- return d.typingCache.RemoveUser(userID, roomID)
+) types.StreamPosition {
+ return types.StreamPosition(d.typingCache.RemoveUser(userID, roomID))
}
func (d *SyncServerDatasource) addInvitesToResponse(
ctx context.Context, txn *sql.Tx,
userID string,
- fromPos, toPos int64,
+ fromPos, toPos types.StreamPosition,
res *types.Response,
) error {
invites, err := d.invites.selectInviteEventsInRange(
- ctx, txn, userID, int64(fromPos), int64(toPos),
+ ctx, txn, userID, fromPos, toPos,
)
if err != nil {
return err
@@ -577,12 +762,32 @@ func (d *SyncServerDatasource) addInvitesToResponse(
return nil
}
+// Retrieve the backward topology position, i.e. the position of the
+// oldest event in the room's topology.
+func (d *SyncServerDatasource) getBackwardTopologyPos(
+ ctx context.Context,
+ events []types.StreamEvent,
+) (pos types.StreamPosition, err error) {
+ if len(events) > 0 {
+ pos, err = d.topology.selectPositionInTopology(ctx, events[0].EventID())
+ if err != nil {
+ return
+ }
+ }
+ if pos-1 <= 0 {
+ pos = types.StreamPosition(1)
+ } else {
+ pos = pos - 1
+ }
+ return
+}
+
// addRoomDeltaToResponse adds a room state delta to a sync response
func (d *SyncServerDatasource) addRoomDeltaToResponse(
ctx context.Context,
device *authtypes.Device,
txn *sql.Tx,
- fromPos, toPos int64,
+ fromPos, toPos types.StreamPosition,
delta stateDelta,
numRecentEventsPerRoom int,
res *types.Response,
@@ -598,38 +803,28 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
endPos = delta.membershipPos
}
recentStreamEvents, err := d.events.selectRecentEvents(
- ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom,
+ ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos),
+ numRecentEventsPerRoom, true, true,
)
if err != nil {
return err
}
- recentEvents := streamEventsToEvents(device, recentStreamEvents)
+ recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
- var prevPDUPos int64
-
- if len(recentEvents) == 0 {
- if len(delta.stateEvents) == 0 {
- // Don't bother appending empty room entries
- return nil
- }
-
- // If full_state=true and since is already up to date, then we'll have
- // state events but no recent events.
- prevPDUPos = toPos - 1
- } else {
- prevPDUPos = recentStreamEvents[0].streamPosition - 1
- }
-
- if prevPDUPos <= 0 {
- prevPDUPos = 1
+ var backwardTopologyPos types.StreamPosition
+ backwardTopologyPos, err = d.getBackwardTopologyPos(ctx, recentStreamEvents)
+ if err != nil {
+ return err
}
switch delta.membership {
case gomatrixserverlib.Join:
jr := types.NewJoinResponse()
- // Use the short form of batch token for prev_batch
- jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
+
+ jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@@ -640,8 +835,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
// TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room.
lr := types.NewLeaveResponse()
- // Use the short form of batch token for prev_batch
- lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
+ lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@@ -656,9 +852,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
func (d *SyncServerDatasource) fetchStateEvents(
ctx context.Context, txn *sql.Tx,
roomIDToEventIDSet map[string]map[string]bool,
- eventIDToEvent map[string]streamEvent,
-) (map[string][]streamEvent, error) {
- stateBetween := make(map[string][]streamEvent)
+ eventIDToEvent map[string]types.StreamEvent,
+) (map[string][]types.StreamEvent, error) {
+ stateBetween := make(map[string][]types.StreamEvent)
missingEvents := make(map[string][]string)
for roomID, ids := range roomIDToEventIDSet {
events := stateBetween[roomID]
@@ -700,7 +896,7 @@ func (d *SyncServerDatasource) fetchStateEvents(
func (d *SyncServerDatasource) fetchMissingStateEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
-) ([]streamEvent, error) {
+) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the
// event.
events, err := d.events.selectEvents(ctx, txn, eventIDs)
@@ -743,7 +939,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents(
// A list of joined room IDs is also returned in case the caller needs it.
func (d *SyncServerDatasource) getStateDeltas(
ctx context.Context, device *authtypes.Device, txn *sql.Tx,
- fromPos, toPos int64, userID string,
+ fromPos, toPos types.StreamPosition, userID string,
stateFilterPart *gomatrix.FilterPart,
) ([]stateDelta, []string, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
@@ -776,7 +972,7 @@ func (d *SyncServerDatasource) getStateDeltas(
if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
if membership == gomatrixserverlib.Join {
// send full room state down instead of a delta
- var s []streamEvent
+ var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart)
if err != nil {
return nil, nil, err
@@ -787,8 +983,8 @@ func (d *SyncServerDatasource) getStateDeltas(
deltas = append(deltas, stateDelta{
membership: membership,
- membershipPos: ev.streamPosition,
- stateEvents: streamEventsToEvents(device, stateStreamEvents),
+ membershipPos: ev.StreamPosition,
+ stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
roomID: roomID,
})
break
@@ -804,7 +1000,7 @@ func (d *SyncServerDatasource) getStateDeltas(
for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, stateDelta{
membership: gomatrixserverlib.Join,
- stateEvents: streamEventsToEvents(device, state[joinedRoomID]),
+ stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
roomID: joinedRoomID,
})
}
@@ -818,7 +1014,7 @@ func (d *SyncServerDatasource) getStateDeltas(
// updates for other rooms.
func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
ctx context.Context, device *authtypes.Device, txn *sql.Tx,
- fromPos, toPos int64, userID string,
+ fromPos, toPos types.StreamPosition, userID string,
stateFilterPart *gomatrix.FilterPart,
) ([]stateDelta, []string, error) {
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@@ -837,7 +1033,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
}
deltas = append(deltas, stateDelta{
membership: gomatrixserverlib.Join,
- stateEvents: streamEventsToEvents(device, s),
+ stateEvents: d.StreamEventsToEvents(device, s),
roomID: joinedRoomID,
})
}
@@ -858,8 +1054,8 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
deltas = append(deltas, stateDelta{
membership: membership,
- membershipPos: ev.streamPosition,
- stateEvents: streamEventsToEvents(device, stateStreamEvents),
+ membershipPos: ev.StreamPosition,
+ stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
roomID: roomID,
})
}
@@ -875,29 +1071,29 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
func (d *SyncServerDatasource) currentStateStreamEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrix.FilterPart,
-) ([]streamEvent, error) {
+) ([]types.StreamEvent, error) {
allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
if err != nil {
return nil, err
}
- s := make([]streamEvent, len(allState))
+ s := make([]types.StreamEvent, len(allState))
for i := 0; i < len(s); i++ {
- s[i] = streamEvent{Event: allState[i], streamPosition: 0}
+ s[i] = types.StreamEvent{Event: allState[i], StreamPosition: 0}
}
return s, nil
}
-// streamEventsToEvents converts streamEvent to Event. If device is non-nil and
+// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event.
-func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event {
+func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event {
out := make([]gomatrixserverlib.Event, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].Event
- if device != nil && in[i].transactionID != nil {
- if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID {
+ if device != nil && in[i].TransactionID != nil {
+ if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
- "transaction_id", in[i].transactionID.TransactionID,
+ "transaction_id", in[i].TransactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go
index 5db4b3a1..4e8a2c83 100644
--- a/syncapi/storage/storage.go
+++ b/syncapi/storage/storage.go
@@ -33,19 +33,26 @@ type Database interface {
common.PartitionStorer
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error)
- WriteEvent(ctx context.Context, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID) (pduPosition int64, returnErr error)
+ WriteEvent(context.Context, *gomatrixserverlib.Event, []gomatrixserverlib.Event, []string, []string, *api.TransactionID, bool) (types.StreamPosition, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error)
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrix.FilterPart) (stateEvents []gomatrixserverlib.Event, err error)
- SyncPosition(ctx context.Context) (types.SyncPosition, error)
- IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.SyncPosition, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
+ SyncPosition(ctx context.Context) (types.PaginationToken, error)
+ IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error)
- GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos int64, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error)
- UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (int64, error)
- AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (int64, error)
+ GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrix.FilterPart) (map[string][]string, error)
+ UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (types.StreamPosition, error)
+ AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (types.StreamPosition, error)
RetireInviteEvent(ctx context.Context, inviteEventID string) error
SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn)
- AddTypingUser(userID, roomID string, expireTime *time.Time) int64
- RemoveTypingUser(userID, roomID string) int64
+ AddTypingUser(userID, roomID string, expireTime *time.Time) types.StreamPosition
+ RemoveTypingUser(userID, roomID string) types.StreamPosition
+ GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
+ EventPositionInTopology(ctx context.Context, eventID string) (types.StreamPosition, error)
+ EventsAtTopologicalPosition(ctx context.Context, roomID string, pos types.StreamPosition) ([]types.StreamEvent, error)
+ BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities []string, err error)
+ MaxTopologicalPosition(ctx context.Context, roomID string) (types.StreamPosition, error)
+ StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event
+ SyncStreamPosition(ctx context.Context) (types.StreamPosition, error)
}
// NewPublicRoomsServerDatabase opens a database connection.
diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go
index 548a17ac..aaee49d3 100644
--- a/syncapi/sync/notifier.go
+++ b/syncapi/sync/notifier.go
@@ -36,7 +36,7 @@ type Notifier struct {
// Protects currPos and userStreams.
streamLock *sync.Mutex
// The latest sync position
- currPos types.SyncPosition
+ currPos types.PaginationToken
// A map of user_id => UserStream which can be used to wake a given user's /sync request.
userStreams map[string]*UserStream
// The last time we cleaned out stale entries from the userStreams map
@@ -46,7 +46,7 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
-func NewNotifier(pos types.SyncPosition) *Notifier {
+func NewNotifier(pos types.PaginationToken) *Notifier {
return &Notifier{
currPos: pos,
roomIDToJoinedUsers: make(map[string]userIDSet),
@@ -68,7 +68,7 @@ func NewNotifier(pos types.SyncPosition) *Notifier {
// event type it handles, leaving other fields as 0.
func (n *Notifier) OnNewEvent(
ev *gomatrixserverlib.Event, roomID string, userIDs []string,
- posUpdate types.SyncPosition,
+ posUpdate types.PaginationToken,
) {
// update the current position then notify relevant /sync streams.
// This needs to be done PRIOR to waking up users as they will read this value.
@@ -151,7 +151,7 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
}
// CurrentPosition returns the current sync position
-func (n *Notifier) CurrentPosition() types.SyncPosition {
+func (n *Notifier) CurrentPosition() types.PaginationToken {
n.streamLock.Lock()
defer n.streamLock.Unlock()
@@ -173,7 +173,7 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
}
}
-func (n *Notifier) wakeupUsers(userIDs []string, newPos types.SyncPosition) {
+func (n *Notifier) wakeupUsers(userIDs []string, newPos types.PaginationToken) {
for _, userID := range userIDs {
stream := n.fetchUserStream(userID, false)
if stream != nil {
diff --git a/syncapi/sync/notifier_test.go b/syncapi/sync/notifier_test.go
index 808e07cc..02da0f7e 100644
--- a/syncapi/sync/notifier_test.go
+++ b/syncapi/sync/notifier_test.go
@@ -32,11 +32,11 @@ var (
randomMessageEvent gomatrixserverlib.Event
aliceInviteBobEvent gomatrixserverlib.Event
bobLeaveEvent gomatrixserverlib.Event
- syncPositionVeryOld types.SyncPosition
- syncPositionBefore types.SyncPosition
- syncPositionAfter types.SyncPosition
- syncPositionNewEDU types.SyncPosition
- syncPositionAfter2 types.SyncPosition
+ syncPositionVeryOld types.PaginationToken
+ syncPositionBefore types.PaginationToken
+ syncPositionAfter types.PaginationToken
+ syncPositionNewEDU types.PaginationToken
+ syncPositionAfter2 types.PaginationToken
)
var (
@@ -46,9 +46,9 @@ var (
)
func init() {
- baseSyncPos := types.SyncPosition{
- PDUPosition: 0,
- TypingPosition: 0,
+ baseSyncPos := types.PaginationToken{
+ PDUPosition: 0,
+ EDUTypingPosition: 0,
}
syncPositionVeryOld = baseSyncPos
@@ -61,7 +61,7 @@ func init() {
syncPositionAfter.PDUPosition = 12
syncPositionNewEDU = syncPositionAfter
- syncPositionNewEDU.TypingPosition = 1
+ syncPositionNewEDU.EDUTypingPosition = 1
syncPositionAfter2 = baseSyncPos
syncPositionAfter2.PDUPosition = 13
@@ -119,7 +119,7 @@ func TestImmediateNotification(t *testing.T) {
t.Fatalf("TestImmediateNotification error: %s", err)
}
if pos != syncPositionBefore {
- t.Fatalf("TestImmediateNotification want %d, got %d", syncPositionBefore, pos)
+ t.Fatalf("TestImmediateNotification want %v, got %v", syncPositionBefore, pos)
}
}
@@ -138,7 +138,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
t.Errorf("TestNewEventAndJoinedToRoom error: %s", err)
}
if pos != syncPositionAfter {
- t.Errorf("TestNewEventAndJoinedToRoom want %d, got %d", syncPositionAfter, pos)
+ t.Errorf("TestNewEventAndJoinedToRoom want %v, got %v", syncPositionAfter, pos)
}
wg.Done()
}()
@@ -166,7 +166,7 @@ func TestNewInviteEventForUser(t *testing.T) {
t.Errorf("TestNewInviteEventForUser error: %s", err)
}
if pos != syncPositionAfter {
- t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionAfter, pos)
+ t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionAfter, pos)
}
wg.Done()
}()
@@ -194,7 +194,7 @@ func TestEDUWakeup(t *testing.T) {
t.Errorf("TestNewInviteEventForUser error: %s", err)
}
if pos != syncPositionNewEDU {
- t.Errorf("TestNewInviteEventForUser want %d, got %d", syncPositionNewEDU, pos)
+ t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionNewEDU, pos)
}
wg.Done()
}()
@@ -222,7 +222,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
t.Errorf("TestMultipleRequestWakeup error: %s", err)
}
if pos != syncPositionAfter {
- t.Errorf("TestMultipleRequestWakeup want %d, got %d", syncPositionAfter, pos)
+ t.Errorf("TestMultipleRequestWakeup want %v, got %v", syncPositionAfter, pos)
}
wg.Done()
}
@@ -262,7 +262,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
}
if pos != syncPositionAfter {
- t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter, pos)
+ t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter, pos)
}
leaveWG.Done()
}()
@@ -281,7 +281,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %s", err)
}
if pos != syncPositionAfter2 {
- t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %d, got %d", syncPositionAfter2, pos)
+ t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter2, pos)
}
aliceWG.Done()
}()
@@ -305,14 +305,14 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
time.Sleep(1 * time.Millisecond)
}
-func waitForEvents(n *Notifier, req syncRequest) (types.SyncPosition, error) {
+func waitForEvents(n *Notifier, req syncRequest) (types.PaginationToken, error) {
listener := n.GetListener(req)
defer listener.Close()
select {
case <-time.After(5 * time.Second):
- return types.SyncPosition{}, fmt.Errorf(
- "waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since,
+ return types.PaginationToken{}, fmt.Errorf(
+ "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since,
)
case <-listener.GetNotifyChannel(*req.since):
p := listener.GetSyncPosition()
@@ -337,7 +337,7 @@ func lockedFetchUserStream(n *Notifier, userID string) *UserStream {
return n.fetchUserStream(userID, true)
}
-func newTestSyncRequest(userID string, since types.SyncPosition) syncRequest {
+func newTestSyncRequest(userID string, since types.PaginationToken) syncRequest {
return syncRequest{
device: authtypes.Device{UserID: userID},
timeout: 1 * time.Minute,
diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go
index a5d2f60f..f2e199d2 100644
--- a/syncapi/sync/request.go
+++ b/syncapi/sync/request.go
@@ -16,10 +16,8 @@ package sync
import (
"context"
- "errors"
"net/http"
"strconv"
- "strings"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@@ -38,7 +36,7 @@ type syncRequest struct {
device authtypes.Device
limit int
timeout time.Duration
- since *types.SyncPosition // nil means that no since token was supplied
+ since *types.PaginationToken // nil means that no since token was supplied
wantFullState bool
log *log.Entry
}
@@ -47,7 +45,7 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e
timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state")
wantFullState := fullState != "" && fullState != "false"
- since, err := getSyncStreamPosition(req.URL.Query().Get("since"))
+ since, err := getPaginationToken(req.URL.Query().Get("since"))
if err != nil {
return nil, err
}
@@ -75,41 +73,14 @@ func getTimeout(timeoutMS string) time.Duration {
}
// getSyncStreamPosition tries to parse a 'since' token taken from the API to a
-// types.SyncPosition. If the string is empty then (nil, nil) is returned.
+// types.PaginationToken. If the string is empty then (nil, nil) is returned.
// There are two forms of tokens: The full length form containing all PDU and EDU
// positions separated by "_", and the short form containing only the PDU
// position. Short form can be used for, e.g., `prev_batch` tokens.
-func getSyncStreamPosition(since string) (*types.SyncPosition, error) {
+func getPaginationToken(since string) (*types.PaginationToken, error) {
if since == "" {
return nil, nil
}
- posStrings := strings.Split(since, "_")
- if len(posStrings) != 2 && len(posStrings) != 1 {
- // A token can either be full length or short (PDU-only).
- return nil, errors.New("malformed batch token")
- }
-
- positions := make([]int64, len(posStrings))
- for i, posString := range posStrings {
- pos, err := strconv.ParseInt(posString, 10, 64)
- if err != nil {
- return nil, err
- }
- positions[i] = pos
- }
-
- if len(positions) == 2 {
- // Full length token; construct SyncPosition with every entry in
- // `positions`. These entries must have the same order with the fields
- // in struct SyncPosition, so we disable the govet check below.
- return &types.SyncPosition{ //nolint:govet
- positions[0], positions[1],
- }, nil
- } else {
- // Token with PDU position only
- return &types.SyncPosition{
- PDUPosition: positions[0],
- }, nil
- }
+ return types.NewPaginationTokenFromString(since)
}
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index d75f07e6..5a3ae880 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -130,7 +130,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
}
}
-func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncPosition) (res *types.Response, err error) {
+func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.PaginationToken) (res *types.Response, err error) {
// TODO: handle ignored users
if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
@@ -143,7 +143,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.SyncP
}
accountDataFilter := gomatrix.DefaultFilterPart() // TODO: use filter provided in req instead
- res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter)
+ res, err = rp.appendAccountData(res, req.device.UserID, req, int64(latestPos.PDUPosition), &accountDataFilter)
return
}
@@ -183,7 +183,11 @@ func (rp *RequestPool) appendAccountData(
}
// Sync is not initial, get all account data since the latest sync
- dataTypes, err := rp.db.GetAccountDataInRange(req.ctx, userID, req.since.PDUPosition, currentPos, accountDataFilter)
+ dataTypes, err := rp.db.GetAccountDataInRange(
+ req.ctx, userID,
+ types.StreamPosition(req.since.PDUPosition), types.StreamPosition(currentPos),
+ accountDataFilter,
+ )
if err != nil {
return nil, err
}
diff --git a/syncapi/sync/userstream.go b/syncapi/sync/userstream.go
index beb10e48..6eef8644 100644
--- a/syncapi/sync/userstream.go
+++ b/syncapi/sync/userstream.go
@@ -35,7 +35,7 @@ type UserStream struct {
// Closed when there is an update.
signalChannel chan struct{}
// The last sync position that there may have been an update for the user
- pos types.SyncPosition
+ pos types.PaginationToken
// The last time when we had some listeners waiting
timeOfLastChannel time.Time
// The number of listeners waiting
@@ -51,7 +51,7 @@ type UserStreamListener struct {
}
// NewUserStream creates a new user stream
-func NewUserStream(userID string, currPos types.SyncPosition) *UserStream {
+func NewUserStream(userID string, currPos types.PaginationToken) *UserStream {
return &UserStream{
UserID: userID,
timeOfLastChannel: time.Now(),
@@ -85,7 +85,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener {
}
// Broadcast a new sync position for this user.
-func (s *UserStream) Broadcast(pos types.SyncPosition) {
+func (s *UserStream) Broadcast(pos types.PaginationToken) {
s.lock.Lock()
defer s.lock.Unlock()
@@ -120,7 +120,7 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time {
// GetStreamPosition returns last sync position which the UserStream was
// notified about
-func (s *UserStreamListener) GetSyncPosition() types.SyncPosition {
+func (s *UserStreamListener) GetSyncPosition() types.PaginationToken {
s.userStream.lock.Lock()
defer s.userStream.lock.Unlock()
@@ -132,7 +132,7 @@ func (s *UserStreamListener) GetSyncPosition() types.SyncPosition {
// sincePos specifies from which point we want to be notified about. If there
// has already been an update after sincePos we'll return a closed channel
// immediately.
-func (s *UserStreamListener) GetNotifyChannel(sincePos types.SyncPosition) <-chan struct{} {
+func (s *UserStreamListener) GetNotifyChannel(sincePos types.PaginationToken) <-chan struct{} {
s.userStream.lock.Lock()
defer s.userStream.lock.Unlock()
diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go
index 4738feea..ecf532ca 100644
--- a/syncapi/syncapi.go
+++ b/syncapi/syncapi.go
@@ -21,7 +21,9 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/common/basecomponent"
+ "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/syncapi/consumers"
@@ -37,6 +39,8 @@ func SetupSyncAPIComponent(
deviceDB *devices.Database,
accountsDB *accounts.Database,
queryAPI api.RoomserverQueryAPI,
+ federation *gomatrixserverlib.FederationClient,
+ cfg *config.Dendrite,
) {
syncDB, err := storage.NewSyncServerDatasource(string(base.Cfg.Database.SyncAPI))
if err != nil {
@@ -77,5 +81,5 @@ func SetupSyncAPIComponent(
logrus.WithError(err).Panicf("failed to start typing server consumer")
}
- routing.Setup(base.APIMux, requestPool, syncDB, deviceDB)
+ routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, federation, queryAPI, cfg)
}
diff --git a/syncapi/types/types.go b/syncapi/types/types.go
index af7ec865..c25a38cd 100644
--- a/syncapi/types/types.go
+++ b/syncapi/types/types.go
@@ -16,45 +16,144 @@ package types
import (
"encoding/json"
+ "errors"
+ "fmt"
"strconv"
+ "strings"
+ "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
)
-// SyncPosition contains the PDU and EDU stream sync positions for a client.
-type SyncPosition struct {
- // PDUPosition is the stream position for PDUs the client is at.
- PDUPosition int64
- // TypingPosition is the client's position for typing notifications.
- TypingPosition int64
+var (
+ // ErrInvalidPaginationTokenType is returned when an attempt at creating a
+ // new instance of PaginationToken with an invalid type (i.e. neither "s"
+ // nor "t").
+ ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)")
+ // ErrInvalidPaginationTokenLen is returned when the pagination token is an
+ // invalid length
+ ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length")
+)
+
+// StreamPosition represents the offset in the sync stream a client is at.
+type StreamPosition int64
+
+// Same as gomatrixserverlib.Event but also has the PDU stream position for this event.
+type StreamEvent struct {
+ gomatrixserverlib.Event
+ StreamPosition StreamPosition
+ TransactionID *api.TransactionID
+ ExcludeFromSync bool
}
-// String implements the Stringer interface.
-func (sp SyncPosition) String() string {
- return strconv.FormatInt(sp.PDUPosition, 10) + "_" +
- strconv.FormatInt(sp.TypingPosition, 10)
+// PaginationTokenType represents the type of a pagination token.
+// It can be either "s" (representing a position in the whole stream of events)
+// or "t" (representing a position in a room's topology/depth).
+type PaginationTokenType string
+
+const (
+ // PaginationTokenTypeStream represents a position in the server's whole
+ // stream of events
+ PaginationTokenTypeStream PaginationTokenType = "s"
+ // PaginationTokenTypeTopology represents a position in a room's topology.
+ PaginationTokenTypeTopology PaginationTokenType = "t"
+)
+
+// PaginationToken represents a pagination token, used for interactions with
+// /sync or /messages, for example.
+type PaginationToken struct {
+ //Position StreamPosition
+ Type PaginationTokenType
+ PDUPosition StreamPosition
+ EDUTypingPosition StreamPosition
}
-// IsAfter returns whether one SyncPosition refers to states newer than another SyncPosition.
-func (sp SyncPosition) IsAfter(other SyncPosition) bool {
- return sp.PDUPosition > other.PDUPosition ||
- sp.TypingPosition > other.TypingPosition
+// NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x"
+// represents the type of a pagination token and "yyyy..." the token itself, and
+// parses it in order to create a new instance of PaginationToken. Returns an
+// error if the token couldn't be parsed into an int64, or if the token type
+// isn't a known type (returns ErrInvalidPaginationTokenType in the latter
+// case).
+func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) {
+ if len(s) == 0 {
+ return nil, ErrInvalidPaginationTokenLen
+ }
+
+ token = new(PaginationToken)
+ var positions []string
+
+ switch t := PaginationTokenType(s[:1]); t {
+ case PaginationTokenTypeStream, PaginationTokenTypeTopology:
+ token.Type = t
+ positions = strings.Split(s[1:], "_")
+ default:
+ token.Type = PaginationTokenTypeStream
+ positions = strings.Split(s, "_")
+ }
+
+ // Try to get the PDU position.
+ if len(positions) >= 1 {
+ if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil {
+ return nil, err
+ } else if pduPos < 0 {
+ return nil, errors.New("negative PDU position not allowed")
+ } else {
+ token.PDUPosition = StreamPosition(pduPos)
+ }
+ }
+
+ // Try to get the typing position.
+ if len(positions) >= 2 {
+ if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil {
+ return nil, err
+ } else if typPos < 0 {
+ return nil, errors.New("negative EDU typing position not allowed")
+ } else {
+ token.EDUTypingPosition = StreamPosition(typPos)
+ }
+ }
+
+ return
+}
+
+// NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a
+// StreamPosition and returns an instance of PaginationToken.
+func NewPaginationTokenFromTypeAndPosition(
+ t PaginationTokenType, pdupos StreamPosition, typpos StreamPosition,
+) (p *PaginationToken) {
+ return &PaginationToken{
+ Type: t,
+ PDUPosition: pdupos,
+ EDUTypingPosition: typpos,
+ }
}
-// WithUpdates returns a copy of the SyncPosition with updates applied from another SyncPosition.
-// If the latter SyncPosition contains a field that is not 0, it is considered an update,
-// and its value will replace the corresponding value in the SyncPosition on which WithUpdates is called.
-func (sp SyncPosition) WithUpdates(other SyncPosition) SyncPosition {
- ret := sp
+// String translates a PaginationToken to a string of the "xyyyy..." (see
+// NewPaginationToken to know what it represents).
+func (p *PaginationToken) String() string {
+ return fmt.Sprintf("%s%d_%d", p.Type, p.PDUPosition, p.EDUTypingPosition)
+}
+
+// WithUpdates returns a copy of the PaginationToken with updates applied from another PaginationToken.
+// If the latter PaginationToken contains a field that is not 0, it is considered an update,
+// and its value will replace the corresponding value in the PaginationToken on which WithUpdates is called.
+func (pt *PaginationToken) WithUpdates(other PaginationToken) PaginationToken {
+ ret := *pt
if other.PDUPosition != 0 {
ret.PDUPosition = other.PDUPosition
}
- if other.TypingPosition != 0 {
- ret.TypingPosition = other.TypingPosition
+ if other.EDUTypingPosition != 0 {
+ ret.EDUTypingPosition = other.EDUTypingPosition
}
return ret
}
+// IsAfter returns whether one PaginationToken refers to states newer than another PaginationToken.
+func (sp *PaginationToken) IsAfter(other PaginationToken) bool {
+ return sp.PDUPosition > other.PDUPosition ||
+ sp.EDUTypingPosition > other.EDUTypingPosition
+}
+
// PrevEventRef represents a reference to a previous event in a state event upgrade
type PrevEventRef struct {
PrevContent json.RawMessage `json:"prev_content"`
@@ -79,9 +178,9 @@ type Response struct {
}
// NewResponse creates an empty response with initialised maps.
-func NewResponse(pos SyncPosition) *Response {
+func NewResponse(token PaginationToken) *Response {
res := Response{
- NextBatch: pos.String(),
+ NextBatch: token.String(),
}
// Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section,
// so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors.
@@ -96,6 +195,14 @@ func NewResponse(pos SyncPosition) *Response {
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
+ // Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume
+ // we'll always return a stream token.
+ res.NextBatch = NewPaginationTokenFromTypeAndPosition(
+ PaginationTokenTypeStream,
+ StreamPosition(token.PDUPosition),
+ StreamPosition(token.EDUTypingPosition),
+ ).String()
+
return &res
}
diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go
new file mode 100644
index 00000000..f4c84e0d
--- /dev/null
+++ b/syncapi/types/types_test.go
@@ -0,0 +1,52 @@
+package types
+
+import "testing"
+
+func TestNewPaginationTokenFromString(t *testing.T) {
+ shouldPass := map[string]PaginationToken{
+ "2": PaginationToken{
+ Type: PaginationTokenTypeStream,
+ PDUPosition: 2,
+ },
+ "s4": PaginationToken{
+ Type: PaginationTokenTypeStream,
+ PDUPosition: 4,
+ },
+ "s3_1": PaginationToken{
+ Type: PaginationTokenTypeStream,
+ PDUPosition: 3,
+ EDUTypingPosition: 1,
+ },
+ "t3_1_4": PaginationToken{
+ Type: PaginationTokenTypeTopology,
+ PDUPosition: 3,
+ EDUTypingPosition: 1,
+ },
+ }
+
+ shouldFail := []string{
+ "",
+ "s_1",
+ "s_",
+ "a3_4",
+ "b",
+ "b-1",
+ "-4",
+ }
+
+ for test, expected := range shouldPass {
+ result, err := NewPaginationTokenFromString(test)
+ if err != nil {
+ t.Error(err)
+ }
+ if *result != expected {
+ t.Errorf("expected %v but got %v", expected.String(), result.String())
+ }
+ }
+
+ for _, test := range shouldFail {
+ if _, err := NewPaginationTokenFromString(test); err == nil {
+ t.Errorf("input '%v' should have errored but didn't", test)
+ }
+ }
+}