diff options
37 files changed, 1754 insertions, 1522 deletions
diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 796cc61e..735f6718 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -34,6 +34,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -46,7 +47,7 @@ type OutputClientDataConsumer struct { topic string topicReIndex string db storage.Database - stream types.StreamProvider + stream streams.StreamProvider notifier *notifier.Notifier serverName gomatrixserverlib.ServerName fts *fulltext.Search @@ -61,7 +62,7 @@ func NewOutputClientDataConsumer( nats *nats.Conn, store storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, fts *fulltext.Search, ) *OutputClientDataConsumer { return &OutputClientDataConsumer{ diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index c42e7197..dc7d9e20 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -40,7 +41,7 @@ type OutputKeyChangeEventConsumer struct { topic string db storage.Database notifier *notifier.Notifier - stream types.StreamProvider + stream streams.StreamProvider serverName gomatrixserverlib.ServerName // our server name rsAPI roomserverAPI.SyncRoomserverAPI } @@ -55,7 +56,7 @@ func NewOutputKeyChangeEventConsumer( rsAPI roomserverAPI.SyncRoomserverAPI, store storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, ) *OutputKeyChangeEventConsumer { s := &OutputKeyChangeEventConsumer{ ctx: process.Context(), diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 61bdc13d..145059c2 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -39,7 +40,7 @@ type PresenceConsumer struct { requestTopic string presenceTopic string db storage.Database - stream types.StreamProvider + stream streams.StreamProvider notifier *notifier.Notifier deviceAPI api.SyncUserAPI cfg *config.SyncAPI @@ -54,7 +55,7 @@ func NewPresenceConsumer( nats *nats.Conn, db storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, deviceAPI api.SyncUserAPI, ) *PresenceConsumer { return &PresenceConsumer{ diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index 4379dd13..8aaa6573 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -38,7 +39,7 @@ type OutputReceiptEventConsumer struct { durable string topic string db storage.Database - stream types.StreamProvider + stream streams.StreamProvider notifier *notifier.Notifier serverName gomatrixserverlib.ServerName } @@ -51,7 +52,7 @@ func NewOutputReceiptEventConsumer( js nats.JetStreamContext, store storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ ctx: process.Context(), diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 3756ad75..e5e8fe29 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -33,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -45,8 +46,8 @@ type OutputRoomEventConsumer struct { durable string topic string db storage.Database - pduStream types.StreamProvider - inviteStream types.StreamProvider + pduStream streams.StreamProvider + inviteStream streams.StreamProvider notifier *notifier.Notifier fts *fulltext.Search } @@ -58,8 +59,8 @@ func NewOutputRoomEventConsumer( js nats.JetStreamContext, store storage.Database, notifier *notifier.Notifier, - pduStream types.StreamProvider, - inviteStream types.StreamProvider, + pduStream streams.StreamProvider, + inviteStream streams.StreamProvider, rsAPI api.SyncRoomserverAPI, fts *fulltext.Search, ) *OutputRoomEventConsumer { @@ -449,8 +450,14 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head } stateKey := *event.StateKey() - prevEvent, err := s.db.GetStateEvent( - context.TODO(), event.RoomID(), event.Type(), stateKey, + snapshot, err := s.db.NewDatabaseSnapshot(s.ctx) + if err != nil { + return nil, err + } + defer snapshot.Rollback() // nolint:errcheck + + prevEvent, err := snapshot.GetStateEvent( + s.ctx, event.RoomID(), event.Type(), stateKey, ) if err != nil { return event, err diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index c0b43225..49d84cca 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -43,7 +44,7 @@ type OutputSendToDeviceEventConsumer struct { db storage.Database keyAPI keyapi.SyncKeyAPI serverName gomatrixserverlib.ServerName // our server name - stream types.StreamProvider + stream streams.StreamProvider notifier *notifier.Notifier } @@ -56,7 +57,7 @@ func NewOutputSendToDeviceEventConsumer( store storage.Database, keyAPI keyapi.SyncKeyAPI, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, ) *OutputSendToDeviceEventConsumer { return &OutputSendToDeviceEventConsumer{ ctx: process.Context(), diff --git a/syncapi/consumers/typing.go b/syncapi/consumers/typing.go index 88db80f8..67a26239 100644 --- a/syncapi/consumers/typing.go +++ b/syncapi/consumers/typing.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -36,7 +37,7 @@ type OutputTypingEventConsumer struct { durable string topic string eduCache *caching.EDUCache - stream types.StreamProvider + stream streams.StreamProvider notifier *notifier.Notifier } @@ -48,7 +49,7 @@ func NewOutputTypingEventConsumer( js nats.JetStreamContext, eduCache *caching.EDUCache, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, ) *OutputTypingEventConsumer { return &OutputTypingEventConsumer{ ctx: process.Context(), diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go index c9b96f78..3c73dc1f 100644 --- a/syncapi/consumers/userapi.go +++ b/syncapi/consumers/userapi.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -40,7 +41,7 @@ type OutputNotificationDataConsumer struct { topic string db storage.Database notifier *notifier.Notifier - stream types.StreamProvider + stream streams.StreamProvider } // NewOutputNotificationDataConsumer creates a new consumer. Call @@ -51,7 +52,7 @@ func NewOutputNotificationDataConsumer( js nats.JetStreamContext, store storage.Database, notifier *notifier.Notifier, - stream types.StreamProvider, + stream streams.StreamProvider, ) *OutputNotificationDataConsumer { s := &OutputNotificationDataConsumer{ ctx: process.Context(), diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index e73c004e..bbfe19f4 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -100,7 +100,7 @@ func (ev eventVisibility) allowed() (allowed bool) { // Returns the filtered events and an error, if any. func ApplyHistoryVisibilityFilter( ctx context.Context, - syncDB storage.Database, + syncDB storage.DatabaseTransaction, rsAPI api.SyncRoomserverAPI, events []*gomatrixserverlib.HeaderedEvent, alwaysIncludeEventIDs map[string]struct{}, diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go index 87f0d86d..a8e5bf9a 100644 --- a/syncapi/notifier/notifier.go +++ b/syncapi/notifier/notifier.go @@ -318,13 +318,20 @@ func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener { func (n *Notifier) Load(ctx context.Context, db storage.Database) error { n.lock.Lock() defer n.lock.Unlock() - roomToUsers, err := db.AllJoinedUsersInRooms(ctx) + + snapshot, err := db.NewDatabaseSnapshot(ctx) + if err != nil { + return err + } + defer snapshot.Rollback() // nolint:errcheck + + roomToUsers, err := snapshot.AllJoinedUsersInRooms(ctx) if err != nil { return err } n.setUsersJoinedToRooms(roomToUsers) - roomToPeekingDevices, err := db.AllPeekingDevicesInRooms(ctx) + roomToPeekingDevices, err := snapshot.AllPeekingDevicesInRooms(ctx) if err != nil { return err } @@ -338,7 +345,13 @@ func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs [ n.lock.Lock() defer n.lock.Unlock() - roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs) + snapshot, err := db.NewDatabaseSnapshot(ctx) + if err != nil { + return err + } + defer snapshot.Rollback() // nolint:errcheck + + roomToUsers, err := snapshot.AllJoinedUsersInRoom(ctx, roomIDs) if err != nil { return err } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 1ebdfe60..1ce34b85 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -51,6 +51,12 @@ func Context( roomID, eventID string, lazyLoadCache caching.LazyLoadCache, ) util.JSONResponse { + snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + return jsonerror.InternalServerError() + } + defer snapshot.Rollback() // nolint:errcheck + filter, err := parseRoomEventFilter(req) if err != nil { errMsg := "" @@ -97,7 +103,7 @@ func Context( ContainsURL: filter.ContainsURL, } - id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID) + id, requestedEvent, err := snapshot.SelectContextEvent(ctx, roomID, eventID) if err != nil { if err == sql.ErrNoRows { return util.JSONResponse{ @@ -111,7 +117,7 @@ func Context( // verify the user is allowed to see the context for this room/event startTime := time.Now() - filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") return jsonerror.InternalServerError() @@ -127,20 +133,20 @@ func Context( } } - eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter) + eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch before events") return jsonerror.InternalServerError() } - _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, roomID, filter) + _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, roomID, filter) if err != nil && err != sql.ErrNoRows { logrus.WithError(err).Error("unable to fetch after events") return jsonerror.InternalServerError() } startTime = time.Now() - eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID) + eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") return jsonerror.InternalServerError() @@ -152,7 +158,7 @@ func Context( }).Debug("applied history visibility (context eventsBefore/eventsAfter)") // TODO: Get the actual state at the last event returned by SelectContextAfterEvent - state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil) + state, err := snapshot.CurrentState(ctx, roomID, &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to fetch current room state") return jsonerror.InternalServerError() @@ -173,7 +179,7 @@ func Context( if len(response.State) > filter.Limit { response.State = response.State[len(response.State)-filter.Limit:] } - start, end, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter) + start, end, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter) if err == nil { response.End = end.String() response.Start = start.String() @@ -188,7 +194,7 @@ func Context( // by combining the events before and after the context event. Returns the filtered events, // and an error, if any. func applyHistoryVisibilityOnContextEvents( - ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI, + ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI, eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent, userID string, ) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) { @@ -205,7 +211,7 @@ func applyHistoryVisibilityOnContextEvents( } allEvents := append(eventsBefore, eventsAfter...) - filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context") + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, allEvents, nil, userID, "context") if err != nil { return nil, nil, err } @@ -222,15 +228,15 @@ func applyHistoryVisibilityOnContextEvents( return filteredBefore, filteredAfter, nil } -func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { +func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { if len(startEvents) > 0 { - start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID()) + start, err = snapshot.EventPositionInTopology(ctx, startEvents[0].EventID()) if err != nil { return } } if len(endEvents) > 0 { - end, err = syncDB.EventPositionInTopology(ctx, endEvents[0].EventID()) + end, err = snapshot.EventPositionInTopology(ctx, endEvents[0].EventID()) } return } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 03614302..8f3ed3f5 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/syncapi/internal" @@ -39,6 +40,7 @@ import ( type messagesReq struct { ctx context.Context db storage.Database + snapshot storage.DatabaseTransaction rsAPI api.SyncRoomserverAPI cfg *config.SyncAPI roomID string @@ -70,6 +72,16 @@ func OnIncomingMessagesRequest( ) util.JSONResponse { var err error + // NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we + // expect to be able to write to the database in response to a /messages + // request that requires backfilling from the roomserver or federation. + snapshot, err := db.NewDatabaseTransaction(req.Context()) + if err != nil { + return jsonerror.InternalServerError() + } + var succeeded bool + defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err) + // check if the user has already forgotten about this room isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI) if err != nil { @@ -132,7 +144,7 @@ func OnIncomingMessagesRequest( } } else { fromStream = &streamToken - from, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering) + from, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) return jsonerror.InternalServerError() @@ -154,7 +166,7 @@ func OnIncomingMessagesRequest( JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()), } } else { - to, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering) + to, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering) if err != nil { logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken) return jsonerror.InternalServerError() @@ -165,7 +177,7 @@ func OnIncomingMessagesRequest( // If "to" isn't provided, it defaults to either the earliest stream // position (if we're going backward) or to the latest one (if we're // going forward). - to, err = setToDefault(req.Context(), db, backwardOrdering, roomID) + to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed") return jsonerror.InternalServerError() @@ -186,6 +198,7 @@ func OnIncomingMessagesRequest( mReq := messagesReq{ ctx: req.Context(), db: db, + snapshot: snapshot, rsAPI: rsAPI, cfg: cfg, roomID: roomID, @@ -217,7 +230,7 @@ func OnIncomingMessagesRequest( Start: start.String(), End: end.String(), } - res.applyLazyLoadMembers(req.Context(), db, roomID, device, filter.LazyLoadMembers, lazyLoadCache) + res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache) // If we didn't return any events, set the end to an empty string, so it will be omitted // in the response JSON. @@ -229,6 +242,7 @@ func OnIncomingMessagesRequest( } // Respond with the events. + succeeded = true return util.JSONResponse{ Code: http.StatusOK, JSON: res, @@ -239,7 +253,7 @@ func OnIncomingMessagesRequest( // LazyLoadMembers enabled. func (m *messagesResp) applyLazyLoadMembers( ctx context.Context, - db storage.Database, + db storage.DatabaseTransaction, roomID string, device *userapi.Device, lazyLoad bool, @@ -292,7 +306,7 @@ func (r *messagesReq) retrieveEvents() ( end types.TopologyToken, err error, ) { // Retrieve the events from the local database. - streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) + streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) return @@ -348,7 +362,7 @@ func (r *messagesReq) retrieveEvents() ( // Apply room history visibility filter startTime := time.Now() - filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages") + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages") logrus.WithFields(logrus.Fields{ "duration": time.Since(startTime), "room_id": r.roomID, @@ -366,7 +380,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st // else to go. This seems to fix Element iOS from looping on /messages endlessly. end = types.TopologyToken{} } else { - end, err = r.db.EventPositionInTopology( + end, err = r.snapshot.EventPositionInTopology( r.ctx, events[0].EventID(), ) // A stream/topological position is a cursor located between two events. @@ -378,7 +392,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st } } else { start = *r.from - end, err = r.db.EventPositionInTopology( + end, err = r.snapshot.EventPositionInTopology( r.ctx, events[len(events)-1].EventID(), ) } @@ -399,7 +413,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st func (r *messagesReq) handleEmptyEventsSlice() ( events []*gomatrixserverlib.HeaderedEvent, err error, ) { - backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) + backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID) // Check if we have backward extremities for this room. if len(backwardExtremities) > 0 { @@ -443,7 +457,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent } // Check if the slice contains a backward extremity. - backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID) + backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID) if err != nil { return } @@ -463,7 +477,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent } // Append the events ve previously retrieved locally. - events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...) + events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...) sort.Sort(eventsByDepth(events)) return @@ -553,7 +567,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] // Returns an error if there was an issue with retrieving the latest position // from the database func setToDefault( - ctx context.Context, db storage.Database, backwardOrdering bool, + ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool, roomID string, ) (to types.TopologyToken, err error) { if backwardOrdering { @@ -561,7 +575,7 @@ func setToDefault( // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. to = types.TopologyToken{} } else { - to, err = db.MaxTopologicalPosition(ctx, roomID) + to, err = snapshot.MaxTopologicalPosition(ctx, roomID) } return diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 341efeb1..bac534a2 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -61,8 +61,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts searchReq.SearchCategories.RoomEvents.Filter.Limit = 5 } + snapshot, err := syncDB.NewDatabaseSnapshot(req.Context()) + if err != nil { + return jsonerror.InternalServerError() + } + defer snapshot.Rollback() // nolint:errcheck + // only search rooms the user is actually joined to - joinedRooms, err := syncDB.RoomIDsWithMembership(ctx, device.UserID, "join") + joinedRooms, err := snapshot.RoomIDsWithMembership(ctx, device.UserID, "join") if err != nil { return jsonerror.InternalServerError() } @@ -161,12 +167,12 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent) for _, event := range evs { - eventsBefore, eventsAfter, err := contextEvents(ctx, syncDB, event, roomFilter, searchReq) + eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq) if err != nil { logrus.WithError(err).Error("failed to get context events") return jsonerror.InternalServerError() } - startToken, endToken, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter) + startToken, endToken, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter) if err != nil { logrus.WithError(err).Error("failed to get start/end") return jsonerror.InternalServerError() @@ -176,7 +182,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts for _, ev := range append(eventsBefore, eventsAfter...) { profile, ok := knownUsersProfiles[event.Sender()] if !ok { - stateEvent, err := syncDB.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender()) + stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender()) if err != nil { logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") continue @@ -209,7 +215,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts groups[event.RoomID()] = roomGroup if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok { stateFilter := gomatrixserverlib.DefaultStateFilter() - state, err := syncDB.CurrentState(ctx, event.RoomID(), &stateFilter, nil) + state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil) if err != nil { logrus.WithError(err).Error("unable to get current state") return jsonerror.InternalServerError() @@ -252,24 +258,24 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts // contextEvents returns the events around a given eventID func contextEvents( ctx context.Context, - syncDB storage.Database, + snapshot storage.DatabaseTransaction, event *gomatrixserverlib.HeaderedEvent, roomFilter *gomatrixserverlib.RoomEventFilter, searchReq SearchRequest, ) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) { - id, _, err := syncDB.SelectContextEvent(ctx, event.RoomID(), event.EventID()) + id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID(), event.EventID()) if err != nil { logrus.WithError(err).Error("failed to query context event") return nil, nil, err } roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit - eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter) + eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter) if err != nil { logrus.WithError(err).Error("failed to query before context event") return nil, nil, err } roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit - _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter) + _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter) if err != nil { logrus.WithError(err).Error("failed to query after context event") return nil, nil, err diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index dd03365e..3732e43f 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -17,19 +17,17 @@ package storage import ( "context" - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) -type Database interface { - Presence +type DatabaseTransaction interface { SharedUsers - Notifications MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) @@ -37,6 +35,7 @@ type Database interface { MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) + MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) @@ -44,21 +43,16 @@ type Database interface { RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) - RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) - GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) - - InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) + InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) - // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) // AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room. AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) - // AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices. AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) // Events lookups a list of event by their event ID. @@ -67,16 +61,6 @@ type Database interface { // Returns an error if there was a problem talking with the database. // Does not include any transaction IDs in the returned events. Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) - // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races - // when generating the sync stream position for this event. Returns the sync stream position for the inserted event. - // Returns an error if there was a problem inserting this event. - WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent, - addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool, - historyVisibility gomatrixserverlib.HistoryVisibility, - ) (types.StreamPosition, error) - // PurgeRoomState completely purges room state from the sync API. This is done when - // receiving an output event that completely resets the state. - PurgeRoomState(ctx context.Context, roomID string) error // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error @@ -91,6 +75,61 @@ type Database interface { // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error) + // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. + GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) + // EventPositionInTopology returns the depth and stream position of the given event. + EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) + // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. + BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) + // MaxTopologicalPosition returns the highest topological position for a given room. + MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) + // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and + // matches the streamevent.transactionID device then the transaction ID gets + // added to the unsigned section of the output event. + StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent + // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the + // relevant events within the given ranges for the supplied user ID and device ID. + SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) + // GetRoomReceipts gets all receipts for a given roomID + GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) + SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) + SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) + SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error) + IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) + // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found + // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty + // string as the membership. + SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms + GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) + GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) +} + +type Database interface { + Presence + Notifications + + NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error) + NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error) + + // Events lookups a list of event by their event ID. + // Returns a list of events matching the requested IDs found in the database. + // If an event is not found in the database then it will be omitted from the list. + // Returns an error if there was a problem talking with the database. + // Does not include any transaction IDs in the returned events. + Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) + // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races + // when generating the sync stream position for this event. Returns the sync stream position for the inserted event. + // Returns an error if there was a problem inserting this event. + WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent, + addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool, + historyVisibility gomatrixserverlib.HistoryVisibility, + ) (types.StreamPosition, error) + // PurgeRoomState completely purges room state from the sync API. This is done when + // receiving an output event that completely resets the state. + PurgeRoomState(ctx context.Context, roomID string) error // UpsertAccountData keeps track of new or updated account data, by saving the type // of the new/updated data, and the user ID and room ID the data is related to (empty) // room ID means the data isn't specific to any room) @@ -114,21 +153,6 @@ type Database interface { // DeletePeek deletes all peeks for a given room by a given user // Returns an error if there was a problem communicating with the database. DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error) - // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. - GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) - // EventPositionInTopology returns the depth and stream position of the given event. - EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) - // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. - BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) - // MaxTopologicalPosition returns the highest topological position for a given room. - MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) - // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and - // matches the streamevent.transactionID device then the transaction ID gets - // added to the unsigned section of the output event. - StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent - // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the - // relevant events within the given ranges for the supplied user ID and device ID. - SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) // CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified @@ -146,29 +170,13 @@ type Database interface { RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error // StoreReceipt stores new receipt events StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) - // GetRoomReceipts gets all receipts for a given roomID - GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) - - SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) - SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) - SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) - - StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error) - - IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found - // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty - // string as the membership. - SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) } type Presence interface { - UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) - PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) - MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) + UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) } type SharedUsers interface { @@ -179,7 +187,4 @@ type SharedUsers interface { type Notifications interface { // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) - - // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms - GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) } diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index f87ccf96..aada70d5 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -55,7 +55,7 @@ const deleteInviteEventSQL = "" + "UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id" const selectInviteEventsInRangeSQL = "" + - "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + + "SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" + " WHERE target_user_id = $1 AND id > $2 AND id <= $3" + " ORDER BY id DESC" @@ -121,23 +121,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent( // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, -) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { +) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) { + var lastPos types.StreamPosition stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { - return nil, nil, err + return nil, nil, lastPos, err } defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") result := map[string]*gomatrixserverlib.HeaderedEvent{} retired := map[string]*gomatrixserverlib.HeaderedEvent{} for rows.Next() { var ( + id types.StreamPosition roomID string eventJSON []byte deleted bool ) - if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil { - return nil, nil, err + if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil { + return nil, nil, lastPos, err + } + if id > lastPos { + lastPos = id } // if we have seen this room before, it has a higher stream position and hence takes priority @@ -150,7 +155,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( var event *gomatrixserverlib.HeaderedEvent if err := json.Unmarshal(eventJSON, &event); err != nil { - return nil, nil, err + return nil, nil, lastPos, err } if deleted { @@ -159,7 +164,10 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( result[roomID] = event } } - return result, retired, rows.Err() + if lastPos == 0 { + lastPos = r.To + } + return result, retired, lastPos, rows.Err() } func (s *inviteEventsStatements) SelectMaxInviteID( diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go new file mode 100644 index 00000000..fb3b295e --- /dev/null +++ b/syncapi/storage/shared/storage_consumer.go @@ -0,0 +1,586 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/tidwall/gjson" + + userapi "github.com/matrix-org/dendrite/userapi/api" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" +) + +// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite +// For now this contains the shared functions +type Database struct { + DB *sql.DB + Writer sqlutil.Writer + Invites tables.Invites + Peeks tables.Peeks + AccountData tables.AccountData + OutputEvents tables.Events + Topology tables.Topology + CurrentRoomState tables.CurrentRoomState + BackwardExtremities tables.BackwardsExtremities + SendToDevice tables.SendToDevice + Filter tables.Filter + Receipts tables.Receipts + Memberships tables.Memberships + NotificationData tables.NotificationData + Ignores tables.Ignores + Presence tables.Presence +} + +func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) { + return d.NewDatabaseTransaction(ctx) + + /* + TODO: Repeatable read is probably the right thing to do here, + but it seems to cause some problems with the invite tests, so + need to investigate that further. + + txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ + // Set the isolation level so that we see a snapshot of the database. + // In PostgreSQL repeatable read transactions will see a snapshot taken + // at the first query, and since the transaction is read-only it can't + // run into any serialisation errors. + // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }) + if err != nil { + return nil, err + } + return &DatabaseTransaction{ + Database: d, + txn: txn, + }, nil + */ +} + +func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) { + txn, err := d.DB.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return &DatabaseTransaction{ + Database: d, + txn: txn, + }, nil +} + +func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { + streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) + if err != nil { + return nil, err + } + + // We don't include a device here as we only include transaction IDs in + // incremental syncs. + return d.StreamEventsToEvents(nil, streamEvents), nil +} + +// AddInviteEvent stores a new invite event for a user. +// If the invite was successfully stored this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. +func (d *Database) AddInviteEvent( + ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent, +) (sp types.StreamPosition, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) + return err + }) + return +} + +// RetireInviteEvent removes an old invite event from the database. +// Returns an error if there was a problem communicating with the database. +func (d *Database) RetireInviteEvent( + ctx context.Context, inviteEventID string, +) (sp types.StreamPosition, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID) + return err + }) + return +} + +// AddPeek tracks the fact that a user has started peeking. +// If the peek was successfully stored this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. +func (d *Database) AddPeek( + ctx context.Context, roomID, userID, deviceID string, +) (sp types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID) + return err + }) + return +} + +// DeletePeek tracks the fact that a user has stopped peeking from the specified +// device. If the peeks was successfully deleted this returns the stream ID it was +// stored at. Returns an error if there was a problem communicating with the database. +func (d *Database) DeletePeek( + ctx context.Context, roomID, userID, deviceID string, +) (sp types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID) + return err + }) + if err == sql.ErrNoRows { + sp = 0 + err = nil + } + return +} + +// DeletePeeks tracks the fact that a user has stopped peeking from all devices +// If the peeks was successfully deleted this returns the stream ID it was stored at. +// Returns an error if there was a problem communicating with the database. +func (d *Database) DeletePeeks( + ctx context.Context, roomID, userID string, +) (sp types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID) + return err + }) + if err == sql.ErrNoRows { + sp = 0 + err = nil + } + return +} + +// UpsertAccountData keeps track of new or updated account data, by saving the type +// of the new/updated data, and the user ID and room ID the data is related to (empty) +// room ID means the data isn't specific to any room) +// If no data with the given type, user ID and room ID exists in the database, +// creates a new row, else update the existing one +// Returns an error if there was an issue with the upsert +func (d *Database) UpsertAccountData( + ctx context.Context, userID, roomID, dataType string, +) (sp types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) + return err + }) + return +} + +func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent { + out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[i].HeaderedEvent + if device != nil && in[i].TransactionID != nil { + if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { + err := out[i].SetUnsignedField( + "transaction_id", in[i].TransactionID.TransactionID, + ) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + } + } + } + } + return out +} + +// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of +// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table +// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. +// This function should always be called within a sqlutil.Writer for safety in SQLite. +func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { + if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { + return err + } + + // Check if we have all of the event's previous events. If an event is + // missing, add it to the room's backward extremities. + prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false) + if err != nil { + return err + } + var found bool + for _, eID := range ev.PrevEventIDs() { + found = false + for _, prevEv := range prevEvents { + if eID == prevEv.EventID() { + found = true + } + } + + // If the event is missing, consider it a backward extremity. + if !found { + if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil { + return err + } + } + } + + return nil +} + +func (d *Database) PurgeRoomState( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) + } + return nil + }) +} + +func (d *Database) WriteEvent( + ctx context.Context, + ev *gomatrixserverlib.HeaderedEvent, + addStateEvents []*gomatrixserverlib.HeaderedEvent, + addStateEventIDs, removeStateEventIDs []string, + transactionID *api.TransactionID, excludeFromSync bool, + historyVisibility gomatrixserverlib.HistoryVisibility, +) (pduPosition types.StreamPosition, returnErr error) { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + ev.Visibility = historyVisibility + pos, err := d.OutputEvents.InsertEvent( + ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility, + ) + if err != nil { + return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err) + } + pduPosition = pos + var topoPosition types.StreamPosition + if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { + return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err) + } + + if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { + return fmt.Errorf("d.handleBackwardExtremities: %w", err) + } + + if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { + // Nothing to do, the event may have just been a message event. + return nil + } + for i := range addStateEvents { + addStateEvents[i].Visibility = historyVisibility + } + return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition) + }) + + return pduPosition, returnErr +} + +// This function should always be called within a sqlutil.Writer for safety in SQLite. +func (d *Database) updateRoomState( + ctx context.Context, txn *sql.Tx, + removedEventIDs []string, + addedEvents []*gomatrixserverlib.HeaderedEvent, + pduPosition types.StreamPosition, + topoPosition types.StreamPosition, +) error { + // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. + for _, eventID := range removedEventIDs { + if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err) + } + } + + for _, event := range addedEvents { + if event.StateKey() == nil { + // ignore non state events + continue + } + var membership *string + if event.Type() == "m.room.member" { + value, err := event.Membership() + if err != nil { + return fmt.Errorf("event.Membership: %w", err) + } + membership = &value + if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil { + return fmt.Errorf("d.Memberships.UpsertMembership: %w", err) + } + } + + if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { + return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err) + } + } + + return nil +} + +func (d *Database) GetFilter( + ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, +) error { + return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID) +} + +func (d *Database) PutFilter( + ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, +) (string, error) { + var filterID string + var err error + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + filterID, err = d.Filter.InsertFilter(ctx, txn, filter, localpart) + return err + }) + return filterID, err +} + +func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error { + redactedEvents, err := d.Events(ctx, []string{redactedEventID}) + if err != nil { + return err + } + if len(redactedEvents) == 0 { + logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction") + return nil + } + eventToRedact := redactedEvents[0].Unwrap() + redactionEvent := redactedBecause.Unwrap() + if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil { + return err + } + + newEvent := eventToRedact.Headered(redactedBecause.RoomVersion) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent) + }) + return err +} + +// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. +// Returns a map of room ID to list of events. +func (d *Database) fetchStateEvents( + ctx context.Context, txn *sql.Tx, + roomIDToEventIDSet map[string]map[string]bool, + eventIDToEvent map[string]types.StreamEvent, +) (map[string][]types.StreamEvent, error) { + stateBetween := make(map[string][]types.StreamEvent) + missingEvents := make(map[string][]string) + for roomID, ids := range roomIDToEventIDSet { + events := stateBetween[roomID] + for id, need := range ids { + if !need { + continue // deleted state + } + e, ok := eventIDToEvent[id] + if ok { + events = append(events, e) + } else { + m := missingEvents[roomID] + m = append(m, id) + missingEvents[roomID] = m + } + } + stateBetween[roomID] = events + } + + if len(missingEvents) > 0 { + // This happens when add_state_ids has an event ID which is not in the provided range. + // We need to explicitly fetch them. + allMissingEventIDs := []string{} + for _, missingEvIDs := range missingEvents { + allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) + } + evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs) + if err != nil { + return nil, err + } + // we know we got them all otherwise an error would've been returned, so just loop the events + for _, ev := range evs { + roomID := ev.RoomID() + stateBetween[roomID] = append(stateBetween[roomID], ev) + } + } + return stateBetween, nil +} + +func (d *Database) fetchMissingStateEvents( + ctx context.Context, txn *sql.Tx, eventIDs []string, +) ([]types.StreamEvent, error) { + // Fetch from the events table first so we pick up the stream ID for the + // event. + events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false) + if err != nil { + return nil, err + } + + have := map[string]bool{} + for _, event := range events { + have[event.EventID()] = true + } + var missing []string + for _, eventID := range eventIDs { + if !have[eventID] { + missing = append(missing, eventID) + } + } + if len(missing) == 0 { + return events, nil + } + + // If they are missing from the events table then they should be state + // events that we received from outside the main event stream. + // These should be in the room state table. + stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing) + + if err != nil { + return nil, err + } + if len(stateEvents) != len(missing) { + logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing)) + + // TODO: Why is this happening? It's probably the roomserver. Uncomment + // this error again when we work out what it is and fix it, otherwise we + // just end up returning lots of 500s to the client and that breaks + // pretty much everything, rather than just sending what we have. + //return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing)) + } + events = append(events, stateEvents...) + return events, nil +} + +func (d *Database) StoreNewSendForDeviceMessage( + ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, +) (newPos types.StreamPosition, err error) { + j, err := json.Marshal(event) + if err != nil { + return 0, err + } + // Delegate the database write task to the SendToDeviceWriter. It'll guarantee + // that we don't lock the table for writes in more than one place. + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + newPos, err = d.SendToDevice.InsertSendToDeviceMessage( + ctx, txn, userID, deviceID, string(j), + ) + return err + }) + if err != nil { + return 0, err + } + return newPos, nil +} + +func (d *Database) CleanSendToDeviceUpdates( + ctx context.Context, + userID, deviceID string, before types.StreamPosition, +) (err error) { + if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before) + }); err != nil { + logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID) + return err + } + return nil +} + +// getMembershipFromEvent returns the value of content.membership iff the event is a state event +// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. +func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) { + if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { + return "", "" + } + membership, err := ev.Membership() + if err != nil { + return "", "" + } + prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str + return membership, prevMembership +} + +// StoreReceipt stores user receipts +func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp) + return err + }) + return +} + +func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, txn, userID, roomID, notificationCount, highlightCount) + return err + }) + return +} + +func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) +} + +func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) +} +func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { + return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) +} + +func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { + return d.Ignores.SelectIgnores(ctx, nil, userID) +} + +func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Ignores.UpsertIgnores(ctx, txn, userID, ignores) + }) +} + +func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { + var pos types.StreamPosition + var err error + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + pos, err = d.Presence.UpsertPresence(ctx, txn, userID, statusMsg, presence, lastActiveTS, fromSync) + return nil + }) + return pos, err +} + +func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUser(ctx, nil, userID) +} + +func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { + return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) +} + +func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) { + return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{ + gomatrixserverlib.MRoomName, + gomatrixserverlib.MRoomTopic, + "m.room.message", + }) +} diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go new file mode 100644 index 00000000..a19135a6 --- /dev/null +++ b/syncapi/storage/shared/storage_sync.go @@ -0,0 +1,574 @@ +package shared + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/syncapi/types" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type DatabaseTransaction struct { + *Database + txn *sql.Tx +} + +func (d *DatabaseTransaction) Commit() error { + if d.txn == nil { + return nil + } + return d.txn.Commit() +} + +func (d *DatabaseTransaction) Rollback() error { + if d.txn == nil { + return nil + } + return d.txn.Rollback() +} + +func (d *DatabaseTransaction) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) { + id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) { + id, err := d.Receipts.SelectMaxReceiptID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) { + id, err := d.Invites.SelectMaxInviteID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) { + id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { + id, err := d.AccountData.SelectMaxAccountDataID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { + id, err := d.NotificationData.SelectMaxID(ctx, d.txn) + if err != nil { + return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) + } + return types.StreamPosition(id), nil +} + +func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { + return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs) +} + +func (d *DatabaseTransaction) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) { + return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.txn, userID, membership) +} + +func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) { + return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos) +} + +func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { + return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships) +} + +func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { + return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) +} + +func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { + return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID) +} + +func (d *DatabaseTransaction) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) { + return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r) +} + +func (d *DatabaseTransaction) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) { + return d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, deviceID, r) +} + +func (d *DatabaseTransaction) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { + return d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos) +} + +// Events lookups a list of event by their event ID. +// Returns a list of events matching the requested IDs found in the database. +// If an event is not found in the database then it will be omitted from the list. +// Returns an error if there was a problem talking with the database. +// Does not include any transaction IDs in the returned events. +func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { + streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false) + if err != nil { + return nil, err + } + + // We don't include a device here as we only include transaction IDs in + // incremental syncs. + return d.StreamEventsToEvents(nil, streamEvents), nil +} + +func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { + return d.CurrentRoomState.SelectJoinedUsers(ctx, d.txn) +} + +func (d *DatabaseTransaction) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { + return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.txn, roomIDs) +} + +func (d *DatabaseTransaction) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { + return d.Peeks.SelectPeekingDevices(ctx, d.txn) +} + +func (d *DatabaseTransaction) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) { + return d.CurrentRoomState.SelectSharedUsers(ctx, d.txn, userID, otherUserIDs) +} + +func (d *DatabaseTransaction) GetStateEvent( + ctx context.Context, roomID, evType, stateKey string, +) (*gomatrixserverlib.HeaderedEvent, error) { + return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey) +} + +func (d *DatabaseTransaction) GetStateEventsForRoom( + ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, +) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { + stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) + return +} + +// GetAccountDataInRange returns all account data for a given user inserted or +// updated between two given positions +// Returns a map following the format data[roomID] = []dataTypes +// If no data is retrieved, returns an empty map +// If there was an issue with the retrieval, returns an error +func (d *DatabaseTransaction) GetAccountDataInRange( + ctx context.Context, userID string, r types.Range, + accountDataFilterPart *gomatrixserverlib.EventFilter, +) (map[string][]string, types.StreamPosition, error) { + return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart) +} + +func (d *DatabaseTransaction) GetEventsInTopologicalRange( + ctx context.Context, + from, to *types.TopologyToken, + roomID string, + filter *gomatrixserverlib.RoomEventFilter, + backwardOrdering bool, +) (events []types.StreamEvent, err error) { + var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition + if backwardOrdering { + // Backward ordering means the 'from' token has a higher depth than the 'to' token + minDepth = to.Depth + maxDepth = from.Depth + // for cases where we have say 5 events with the same depth, the TopologyToken needs to + // know which of the 5 the client has seen. This is done by using the PDU position. + // Events with the same maxDepth but less than this PDU position will be returned. + maxStreamPosForMaxDepth = from.PDUPosition + } else { + // Forward ordering means the 'from' token has a lower depth than the 'to' token. + minDepth = from.Depth + maxDepth = to.Depth + } + + // Select the event IDs from the defined range. + var eIDs []string + eIDs, err = d.Topology.SelectEventIDsInRange( + ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, + ) + if err != nil { + return + } + + // Retrieve the events' contents using their IDs. + events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true) + return +} + +func (d *DatabaseTransaction) BackwardExtremitiesForRoom( + ctx context.Context, roomID string, +) (backwardExtremities map[string][]string, err error) { + return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID) +} + +func (d *DatabaseTransaction) MaxTopologicalPosition( + ctx context.Context, roomID string, +) (types.TopologyToken, error) { + depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) + if err != nil { + return types.TopologyToken{}, err + } + return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil +} + +func (d *DatabaseTransaction) EventPositionInTopology( + ctx context.Context, eventID string, +) (types.TopologyToken, error) { + depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.txn, eventID) + if err != nil { + return types.TopologyToken{}, err + } + return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil +} + +func (d *DatabaseTransaction) StreamToTopologicalPosition( + ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool, +) (types.TopologyToken, error) { + topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.txn, roomID, streamPos, backwardOrdering) + switch { + case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward + return types.TopologyToken{PDUPosition: streamPos}, nil + case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward + topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) + if err != nil { + return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) + } + return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil + case err != nil: // some other error happened + return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err) + default: + return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil + } +} + +// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the +// oldest event in the room's topology. +func (d *DatabaseTransaction) GetBackwardTopologyPos( + ctx context.Context, + events []types.StreamEvent, +) (types.TopologyToken, error) { + zeroToken := types.TopologyToken{} + if len(events) == 0 { + return zeroToken, nil + } + pos, spos, err := d.Topology.SelectPositionInTopology(ctx, d.txn, events[0].EventID()) + if err != nil { + return zeroToken, err + } + tok := types.TopologyToken{Depth: pos, PDUPosition: spos} + tok.Decrement() + return tok, nil +} + +// GetStateDeltas returns the state deltas between fromPos and toPos, +// exclusive of oldPos, inclusive of newPos, for the rooms in which +// the user has new membership events. +// A list of joined room IDs is also returned in case the caller needs it. +func (d *DatabaseTransaction) GetStateDeltas( + ctx context.Context, device *userapi.Device, + r types.Range, userID string, + stateFilter *gomatrixserverlib.StateFilter, +) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { + // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 + // - Get membership list changes for this user in this sync response + // - For each room which has membership list changes: + // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO). + // If it is, then we need to send the full room state down (and 'limited' is always true). + // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block. + // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block. + // - Get all CURRENTLY joined rooms, and add them to 'joined' block. + + // Look up all memberships for the user. We only care about rooms that a + // user has ever interacted with — joined to, kicked/banned from, left. + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + + allRoomIDs := make([]string, 0, len(memberships)) + joinedRoomIDs := make([]string, 0, len(memberships)) + for roomID, membership := range memberships { + allRoomIDs = append(allRoomIDs, roomID) + if membership == gomatrixserverlib.Join { + joinedRoomIDs = append(joinedRoomIDs, roomID) + } + } + + // get all the state events ever (i.e. for all available rooms) between these two positions + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + + // find out which rooms this user is peeking, if any. + // We do this before joins so any peeks get overwritten + peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) + if err != nil && err != sql.ErrNoRows { + return nil, nil, err + } + + // add peek blocks + for _, peek := range peeks { + if peek.New { + // send full room state down instead of a delta + var s []types.StreamEvent + s, err = d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter) + if err != nil { + if err == sql.ErrNoRows { + continue + } + return nil, nil, err + } + state[peek.RoomID] = s + } + if !peek.Deleted { + deltas = append(deltas, types.StateDelta{ + Membership: gomatrixserverlib.Peek, + StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), + RoomID: peek.RoomID, + }) + } + } + + // handle newly joined rooms and non-joined rooms + newlyJoinedRooms := make(map[string]bool, len(state)) + for roomID, stateStreamEvents := range state { + for _, ev := range stateStreamEvents { + if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" { + if membership == gomatrixserverlib.Join && prevMembership != membership { + // send full room state down instead of a delta + var s []types.StreamEvent + s, err = d.currentStateStreamEventsForRoom(ctx, roomID, stateFilter) + if err != nil { + if err == sql.ErrNoRows { + continue + } + return nil, nil, err + } + state[roomID] = s + newlyJoinedRooms[roomID] = true + continue // we'll add this room in when we do joined rooms + } + + deltas = append(deltas, types.StateDelta{ + Membership: membership, + MembershipPos: ev.StreamPosition, + StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + RoomID: roomID, + }) + break + } + } + } + + // Add in currently joined rooms + for _, joinedRoomID := range joinedRoomIDs { + deltas = append(deltas, types.StateDelta{ + Membership: gomatrixserverlib.Join, + StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), + RoomID: joinedRoomID, + NewlyJoined: newlyJoinedRooms[joinedRoomID], + }) + } + + return deltas, joinedRoomIDs, nil +} + +// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync +// requests with full_state=true. +// Fetches full state for all joined rooms and uses selectStateInRange to get +// updates for other rooms. +func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( + ctx context.Context, device *userapi.Device, + r types.Range, userID string, + stateFilter *gomatrixserverlib.StateFilter, +) ([]types.StateDelta, []string, error) { + // Look up all memberships for the user. We only care about rooms that a + // user has ever interacted with — joined to, kicked/banned from, left. + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + + allRoomIDs := make([]string, 0, len(memberships)) + joinedRoomIDs := make([]string, 0, len(memberships)) + for roomID, membership := range memberships { + allRoomIDs = append(allRoomIDs, roomID) + if membership == gomatrixserverlib.Join { + joinedRoomIDs = append(joinedRoomIDs, roomID) + } + } + + // Use a reasonable initial capacity + deltas := make(map[string]types.StateDelta) + + peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r) + if err != nil && err != sql.ErrNoRows { + return nil, nil, err + } + + // Add full states for all peeking rooms + for _, peek := range peeks { + if !peek.Deleted { + s, stateErr := d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter) + if stateErr != nil { + if stateErr == sql.ErrNoRows { + continue + } + return nil, nil, stateErr + } + deltas[peek.RoomID] = types.StateDelta{ + Membership: gomatrixserverlib.Peek, + StateEvents: d.StreamEventsToEvents(device, s), + RoomID: peek.RoomID, + } + } + } + + // Get all the state events ever between these two positions + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err + } + + for roomID, stateStreamEvents := range state { + for _, ev := range stateStreamEvents { + if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" { + if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. + deltas[roomID] = types.StateDelta{ + Membership: membership, + MembershipPos: ev.StreamPosition, + StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + RoomID: roomID, + } + } + + break + } + } + } + + // Add full states for all joined rooms + for _, joinedRoomID := range joinedRoomIDs { + s, stateErr := d.currentStateStreamEventsForRoom(ctx, joinedRoomID, stateFilter) + if stateErr != nil { + if stateErr == sql.ErrNoRows { + continue + } + return nil, nil, stateErr + } + deltas[joinedRoomID] = types.StateDelta{ + Membership: gomatrixserverlib.Join, + StateEvents: d.StreamEventsToEvents(device, s), + RoomID: joinedRoomID, + } + } + + // Create a response array. + result := make([]types.StateDelta, len(deltas)) + i := 0 + for _, delta := range deltas { + result[i] = delta + i++ + } + + return result, joinedRoomIDs, nil +} + +func (d *DatabaseTransaction) currentStateStreamEventsForRoom( + ctx context.Context, roomID string, + stateFilter *gomatrixserverlib.StateFilter, +) ([]types.StreamEvent, error) { + allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil) + if err != nil { + return nil, err + } + s := make([]types.StreamEvent, len(allState)) + for i := 0; i < len(s); i++ { + s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0} + } + return s, nil +} + +func (d *DatabaseTransaction) SendToDeviceUpdatesForSync( + ctx context.Context, + userID, deviceID string, + from, to types.StreamPosition, +) (types.StreamPosition, []types.SendToDeviceEvent, error) { + // First of all, get our send-to-device updates for this user. + lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, d.txn, userID, deviceID, from, to) + if err != nil { + return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) + } + // If there's nothing to do then stop here. + if len(events) == 0 { + return to, nil, nil + } + return lastPos, events, nil +} + +func (d *DatabaseTransaction) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) { + _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos) + return receipts, err +} + +func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) { + roomIDs := make([]string, 0, len(rooms)) + for roomID, membership := range rooms { + if membership != gomatrixserverlib.Join { + continue + } + roomIDs = append(roomIDs, roomID) + } + return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) +} + +func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUser(ctx, d.txn, userID) +} + +func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { + return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter) +} + +func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { + return d.Presence.GetMaxPresenceID(ctx, d.txn) +} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go deleted file mode 100644 index a05e6880..00000000 --- a/syncapi/storage/shared/syncserver.go +++ /dev/null @@ -1,1103 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package shared - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - - "github.com/tidwall/gjson" - - userapi "github.com/matrix-org/dendrite/userapi/api" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/syncapi/storage/tables" - "github.com/matrix-org/dendrite/syncapi/types" -) - -// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite -// For now this contains the shared functions -type Database struct { - DB *sql.DB - Writer sqlutil.Writer - Invites tables.Invites - Peeks tables.Peeks - AccountData tables.AccountData - OutputEvents tables.Events - Topology tables.Topology - CurrentRoomState tables.CurrentRoomState - BackwardExtremities tables.BackwardsExtremities - SendToDevice tables.SendToDevice - Filter tables.Filter - Receipts tables.Receipts - Memberships tables.Memberships - NotificationData tables.NotificationData - Ignores tables.Ignores - Presence tables.Presence -} - -func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) { - return d.DB.BeginTx(ctx, &sql.TxOptions{ - // Set the isolation level so that we see a snapshot of the database. - // In PostgreSQL repeatable read transactions will see a snapshot taken - // at the first query, and since the transaction is read-only it can't - // run into any serialisation errors. - // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ - Isolation: sql.LevelRepeatableRead, - ReadOnly: true, - }) -} - -func (d *Database) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) { - id, err := d.OutputEvents.SelectMaxEventID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) { - id, err := d.Receipts.SelectMaxReceiptID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) { - id, err := d.Invites.SelectMaxInviteID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) { - id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { - id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) { - id, err := d.NotificationData.SelectMaxID(ctx, nil) - if err != nil { - return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err) - } - return types.StreamPosition(id), nil -} - -func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs) -} - -func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) { - return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) -} - -func (d *Database) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) { - return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos) -} - -func (d *Database) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { - return d.Memberships.SelectHeroes(ctx, nil, roomID, userID, memberships) -} - -func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { - return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) -} - -func (d *Database) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { - return d.Topology.SelectPositionInTopology(ctx, nil, eventID) -} - -func (d *Database) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { - return d.Invites.SelectInviteEventsInRange(ctx, nil, targetUserID, r) -} - -func (d *Database) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) { - return d.Peeks.SelectPeeksInRange(ctx, nil, userID, deviceID, r) -} - -func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { - return d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos) -} - -// Events lookups a list of event by their event ID. -// Returns a list of events matching the requested IDs found in the database. -// If an event is not found in the database then it will be omitted from the list. -// Returns an error if there was a problem talking with the database. -// Does not include any transaction IDs in the returned events. -func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false) - if err != nil { - return nil, err - } - - // We don't include a device here as we only include transaction IDs in - // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil -} - -func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { - return d.CurrentRoomState.SelectJoinedUsers(ctx, nil) -} - -func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) { - return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, nil, roomIDs) -} - -func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) { - return d.Peeks.SelectPeekingDevices(ctx, nil) -} - -func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) { - return d.CurrentRoomState.SelectSharedUsers(ctx, nil, userID, otherUserIDs) -} - -func (d *Database) GetStateEvent( - ctx context.Context, roomID, evType, stateKey string, -) (*gomatrixserverlib.HeaderedEvent, error) { - return d.CurrentRoomState.SelectStateEvent(ctx, nil, roomID, evType, stateKey) -} - -func (d *Database) GetStateEventsForRoom( - ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, -) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) { - stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter, nil) - return -} - -// AddInviteEvent stores a new invite event for a user. -// If the invite was successfully stored this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *Database) AddInviteEvent( - ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent, -) (sp types.StreamPosition, err error) { - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) - return err - }) - return -} - -// RetireInviteEvent removes an old invite event from the database. -// Returns an error if there was a problem communicating with the database. -func (d *Database) RetireInviteEvent( - ctx context.Context, inviteEventID string, -) (sp types.StreamPosition, err error) { - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID) - return err - }) - return -} - -// AddPeek tracks the fact that a user has started peeking. -// If the peek was successfully stored this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *Database) AddPeek( - ctx context.Context, roomID, userID, deviceID string, -) (sp types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID) - return err - }) - return -} - -// DeletePeek tracks the fact that a user has stopped peeking from the specified -// device. If the peeks was successfully deleted this returns the stream ID it was -// stored at. Returns an error if there was a problem communicating with the database. -func (d *Database) DeletePeek( - ctx context.Context, roomID, userID, deviceID string, -) (sp types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID) - return err - }) - if err == sql.ErrNoRows { - sp = 0 - err = nil - } - return -} - -// DeletePeeks tracks the fact that a user has stopped peeking from all devices -// If the peeks was successfully deleted this returns the stream ID it was stored at. -// Returns an error if there was a problem communicating with the database. -func (d *Database) DeletePeeks( - ctx context.Context, roomID, userID string, -) (sp types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID) - return err - }) - if err == sql.ErrNoRows { - sp = 0 - err = nil - } - return -} - -// GetAccountDataInRange returns all account data for a given user inserted or -// updated between two given positions -// Returns a map following the format data[roomID] = []dataTypes -// If no data is retrieved, returns an empty map -// If there was an issue with the retrieval, returns an error -func (d *Database) GetAccountDataInRange( - ctx context.Context, userID string, r types.Range, - accountDataFilterPart *gomatrixserverlib.EventFilter, -) (map[string][]string, types.StreamPosition, error) { - return d.AccountData.SelectAccountDataInRange(ctx, nil, userID, r, accountDataFilterPart) -} - -// UpsertAccountData keeps track of new or updated account data, by saving the type -// of the new/updated data, and the user ID and room ID the data is related to (empty) -// room ID means the data isn't specific to any room) -// If no data with the given type, user ID and room ID exists in the database, -// creates a new row, else update the existing one -// Returns an error if there was an issue with the upsert -func (d *Database) UpsertAccountData( - ctx context.Context, userID, roomID, dataType string, -) (sp types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) - return err - }) - return -} - -func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent { - out := make([]*gomatrixserverlib.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[i].HeaderedEvent - if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { - err := out[i].SetUnsignedField( - "transaction_id", in[i].TransactionID.TransactionID, - ) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - } - } - } - } - return out -} - -// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of -// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table -// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. -// This function should always be called within a sqlutil.Writer for safety in SQLite. -func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { - if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { - return err - } - - // Check if we have all of the event's previous events. If an event is - // missing, add it to the room's backward extremities. - prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false) - if err != nil { - return err - } - var found bool - for _, eID := range ev.PrevEventIDs() { - found = false - for _, prevEv := range prevEvents { - if eID == prevEv.EventID() { - found = true - } - } - - // If the event is missing, consider it a backward extremity. - if !found { - if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil { - return err - } - } - } - - return nil -} - -func (d *Database) PurgeRoomState( - ctx context.Context, roomID string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. - if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { - return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) - } - return nil - }) -} - -func (d *Database) WriteEvent( - ctx context.Context, - ev *gomatrixserverlib.HeaderedEvent, - addStateEvents []*gomatrixserverlib.HeaderedEvent, - addStateEventIDs, removeStateEventIDs []string, - transactionID *api.TransactionID, excludeFromSync bool, - historyVisibility gomatrixserverlib.HistoryVisibility, -) (pduPosition types.StreamPosition, returnErr error) { - returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - var err error - ev.Visibility = historyVisibility - pos, err := d.OutputEvents.InsertEvent( - ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility, - ) - if err != nil { - return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err) - } - pduPosition = pos - var topoPosition types.StreamPosition - if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { - return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err) - } - - if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { - return fmt.Errorf("d.handleBackwardExtremities: %w", err) - } - - if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { - // Nothing to do, the event may have just been a message event. - return nil - } - for i := range addStateEvents { - addStateEvents[i].Visibility = historyVisibility - } - return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition) - }) - - return pduPosition, returnErr -} - -// This function should always be called within a sqlutil.Writer for safety in SQLite. -func (d *Database) updateRoomState( - ctx context.Context, txn *sql.Tx, - removedEventIDs []string, - addedEvents []*gomatrixserverlib.HeaderedEvent, - pduPosition types.StreamPosition, - topoPosition types.StreamPosition, -) error { - // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. - for _, eventID := range removedEventIDs { - if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { - return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err) - } - } - - for _, event := range addedEvents { - if event.StateKey() == nil { - // ignore non state events - continue - } - var membership *string - if event.Type() == "m.room.member" { - value, err := event.Membership() - if err != nil { - return fmt.Errorf("event.Membership: %w", err) - } - membership = &value - if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil { - return fmt.Errorf("d.Memberships.UpsertMembership: %w", err) - } - } - - if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { - return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err) - } - } - - return nil -} - -func (d *Database) GetEventsInTopologicalRange( - ctx context.Context, - from, to *types.TopologyToken, - roomID string, - filter *gomatrixserverlib.RoomEventFilter, - backwardOrdering bool, -) (events []types.StreamEvent, err error) { - var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition - if backwardOrdering { - // Backward ordering means the 'from' token has a higher depth than the 'to' token - minDepth = to.Depth - maxDepth = from.Depth - // for cases where we have say 5 events with the same depth, the TopologyToken needs to - // know which of the 5 the client has seen. This is done by using the PDU position. - // Events with the same maxDepth but less than this PDU position will be returned. - maxStreamPosForMaxDepth = from.PDUPosition - } else { - // Forward ordering means the 'from' token has a lower depth than the 'to' token. - minDepth = from.Depth - maxDepth = to.Depth - } - - // Select the event IDs from the defined range. - var eIDs []string - eIDs, err = d.Topology.SelectEventIDsInRange( - ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, - ) - if err != nil { - return - } - - // Retrieve the events' contents using their IDs. - events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true) - return -} - -func (d *Database) BackwardExtremitiesForRoom( - ctx context.Context, roomID string, -) (backwardExtremities map[string][]string, err error) { - return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, nil, roomID) -} - -func (d *Database) MaxTopologicalPosition( - ctx context.Context, roomID string, -) (types.TopologyToken, error) { - depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) - if err != nil { - return types.TopologyToken{}, err - } - return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil -} - -func (d *Database) EventPositionInTopology( - ctx context.Context, eventID string, -) (types.TopologyToken, error) { - depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID) - if err != nil { - return types.TopologyToken{}, err - } - return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil -} - -func (d *Database) StreamToTopologicalPosition( - ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool, -) (types.TopologyToken, error) { - topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, nil, roomID, streamPos, backwardOrdering) - switch { - case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward - return types.TopologyToken{PDUPosition: streamPos}, nil - case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward - topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID) - if err != nil { - return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) - } - return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil - case err != nil: // some other error happened - return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err) - default: - return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil - } -} - -func (d *Database) GetFilter( - ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, -) error { - return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID) -} - -func (d *Database) PutFilter( - ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, -) (string, error) { - var filterID string - var err error - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - filterID, err = d.Filter.InsertFilter(ctx, txn, filter, localpart) - return err - }) - return filterID, err -} - -func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error { - redactedEvents, err := d.Events(ctx, []string{redactedEventID}) - if err != nil { - return err - } - if len(redactedEvents) == 0 { - logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction") - return nil - } - eventToRedact := redactedEvents[0].Unwrap() - redactionEvent := redactedBecause.Unwrap() - if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil { - return err - } - - newEvent := eventToRedact.Headered(redactedBecause.RoomVersion) - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent) - }) - return err -} - -// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the -// oldest event in the room's topology. -func (d *Database) GetBackwardTopologyPos( - ctx context.Context, - events []types.StreamEvent, -) (types.TopologyToken, error) { - zeroToken := types.TopologyToken{} - if len(events) == 0 { - return zeroToken, nil - } - pos, spos, err := d.Topology.SelectPositionInTopology(ctx, nil, events[0].EventID()) - if err != nil { - return zeroToken, err - } - tok := types.TopologyToken{Depth: pos, PDUPosition: spos} - tok.Decrement() - return tok, nil -} - -// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database. -// Returns a map of room ID to list of events. -func (d *Database) fetchStateEvents( - ctx context.Context, txn *sql.Tx, - roomIDToEventIDSet map[string]map[string]bool, - eventIDToEvent map[string]types.StreamEvent, -) (map[string][]types.StreamEvent, error) { - stateBetween := make(map[string][]types.StreamEvent) - missingEvents := make(map[string][]string) - for roomID, ids := range roomIDToEventIDSet { - events := stateBetween[roomID] - for id, need := range ids { - if !need { - continue // deleted state - } - e, ok := eventIDToEvent[id] - if ok { - events = append(events, e) - } else { - m := missingEvents[roomID] - m = append(m, id) - missingEvents[roomID] = m - } - } - stateBetween[roomID] = events - } - - if len(missingEvents) > 0 { - // This happens when add_state_ids has an event ID which is not in the provided range. - // We need to explicitly fetch them. - allMissingEventIDs := []string{} - for _, missingEvIDs := range missingEvents { - allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...) - } - evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs) - if err != nil { - return nil, err - } - // we know we got them all otherwise an error would've been returned, so just loop the events - for _, ev := range evs { - roomID := ev.RoomID() - stateBetween[roomID] = append(stateBetween[roomID], ev) - } - } - return stateBetween, nil -} - -func (d *Database) fetchMissingStateEvents( - ctx context.Context, txn *sql.Tx, eventIDs []string, -) ([]types.StreamEvent, error) { - // Fetch from the events table first so we pick up the stream ID for the - // event. - events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false) - if err != nil { - return nil, err - } - - have := map[string]bool{} - for _, event := range events { - have[event.EventID()] = true - } - var missing []string - for _, eventID := range eventIDs { - if !have[eventID] { - missing = append(missing, eventID) - } - } - if len(missing) == 0 { - return events, nil - } - - // If they are missing from the events table then they should be state - // events that we received from outside the main event stream. - // These should be in the room state table. - stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing) - - if err != nil { - return nil, err - } - if len(stateEvents) != len(missing) { - logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing)) - - // TODO: Why is this happening? It's probably the roomserver. Uncomment - // this error again when we work out what it is and fix it, otherwise we - // just end up returning lots of 500s to the client and that breaks - // pretty much everything, rather than just sending what we have. - //return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing)) - } - events = append(events, stateEvents...) - return events, nil -} - -// GetStateDeltas returns the state deltas between fromPos and toPos, -// exclusive of oldPos, inclusive of newPos, for the rooms in which -// the user has new membership events. -// A list of joined room IDs is also returned in case the caller needs it. -func (d *Database) GetStateDeltas( - ctx context.Context, device *userapi.Device, - r types.Range, userID string, - stateFilter *gomatrixserverlib.StateFilter, -) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { - // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 - // - Get membership list changes for this user in this sync response - // - For each room which has membership list changes: - // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO). - // If it is, then we need to send the full room state down (and 'limited' is always true). - // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block. - // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block. - // - Get all CURRENTLY joined rooms, and add them to 'joined' block. - txn, err := d.readOnlySnapshot(ctx) - if err != nil { - return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err) - } - var succeeded bool - defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) - - // Look up all memberships for the user. We only care about rooms that a - // user has ever interacted with — joined to, kicked/banned from, left. - memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - - allRoomIDs := make([]string, 0, len(memberships)) - joinedRoomIDs := make([]string, 0, len(memberships)) - for roomID, membership := range memberships { - allRoomIDs = append(allRoomIDs, roomID) - if membership == gomatrixserverlib.Join { - joinedRoomIDs = append(joinedRoomIDs, roomID) - } - } - - // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - - // find out which rooms this user is peeking, if any. - // We do this before joins so any peeks get overwritten - peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) - if err != nil && err != sql.ErrNoRows { - return nil, nil, err - } - - // add peek blocks - for _, peek := range peeks { - if peek.New { - // send full room state down instead of a delta - var s []types.StreamEvent - s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) - if err != nil { - if err == sql.ErrNoRows { - continue - } - return nil, nil, err - } - state[peek.RoomID] = s - } - if !peek.Deleted { - deltas = append(deltas, types.StateDelta{ - Membership: gomatrixserverlib.Peek, - StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), - RoomID: peek.RoomID, - }) - } - } - - // handle newly joined rooms and non-joined rooms - newlyJoinedRooms := make(map[string]bool, len(state)) - for roomID, stateStreamEvents := range state { - for _, ev := range stateStreamEvents { - if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" { - if membership == gomatrixserverlib.Join && prevMembership != membership { - // send full room state down instead of a delta - var s []types.StreamEvent - s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) - if err != nil { - if err == sql.ErrNoRows { - continue - } - return nil, nil, err - } - state[roomID] = s - newlyJoinedRooms[roomID] = true - continue // we'll add this room in when we do joined rooms - } - - deltas = append(deltas, types.StateDelta{ - Membership: membership, - MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), - RoomID: roomID, - }) - break - } - } - } - - // Add in currently joined rooms - for _, joinedRoomID := range joinedRoomIDs { - deltas = append(deltas, types.StateDelta{ - Membership: gomatrixserverlib.Join, - StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]), - RoomID: joinedRoomID, - NewlyJoined: newlyJoinedRooms[joinedRoomID], - }) - } - - succeeded = true - return deltas, joinedRoomIDs, nil -} - -// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync -// requests with full_state=true. -// Fetches full state for all joined rooms and uses selectStateInRange to get -// updates for other rooms. -func (d *Database) GetStateDeltasForFullStateSync( - ctx context.Context, device *userapi.Device, - r types.Range, userID string, - stateFilter *gomatrixserverlib.StateFilter, -) ([]types.StateDelta, []string, error) { - txn, err := d.readOnlySnapshot(ctx) - if err != nil { - return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err) - } - var succeeded bool - defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) - - // Look up all memberships for the user. We only care about rooms that a - // user has ever interacted with — joined to, kicked/banned from, left. - memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - - allRoomIDs := make([]string, 0, len(memberships)) - joinedRoomIDs := make([]string, 0, len(memberships)) - for roomID, membership := range memberships { - allRoomIDs = append(allRoomIDs, roomID) - if membership == gomatrixserverlib.Join { - joinedRoomIDs = append(joinedRoomIDs, roomID) - } - } - - // Use a reasonable initial capacity - deltas := make(map[string]types.StateDelta) - - peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r) - if err != nil && err != sql.ErrNoRows { - return nil, nil, err - } - - // Add full states for all peeking rooms - for _, peek := range peeks { - if !peek.Deleted { - s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter) - if stateErr != nil { - if stateErr == sql.ErrNoRows { - continue - } - return nil, nil, stateErr - } - deltas[peek.RoomID] = types.StateDelta{ - Membership: gomatrixserverlib.Peek, - StateEvents: d.StreamEventsToEvents(device, s), - RoomID: peek.RoomID, - } - } - } - - // Get all the state events ever between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil - } - return nil, nil, err - } - - for roomID, stateStreamEvents := range state { - for _, ev := range stateStreamEvents { - if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" { - if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above. - deltas[roomID] = types.StateDelta{ - Membership: membership, - MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), - RoomID: roomID, - } - } - - break - } - } - } - - // Add full states for all joined rooms - for _, joinedRoomID := range joinedRoomIDs { - s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) - if stateErr != nil { - if stateErr == sql.ErrNoRows { - continue - } - return nil, nil, stateErr - } - deltas[joinedRoomID] = types.StateDelta{ - Membership: gomatrixserverlib.Join, - StateEvents: d.StreamEventsToEvents(device, s), - RoomID: joinedRoomID, - } - } - - // Create a response array. - result := make([]types.StateDelta, len(deltas)) - i := 0 - for _, delta := range deltas { - result[i] = delta - i++ - } - - succeeded = true - return result, joinedRoomIDs, nil -} - -func (d *Database) currentStateStreamEventsForRoom( - ctx context.Context, txn *sql.Tx, roomID string, - stateFilter *gomatrixserverlib.StateFilter, -) ([]types.StreamEvent, error) { - allState, err := d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter, nil) - if err != nil { - return nil, err - } - s := make([]types.StreamEvent, len(allState)) - for i := 0; i < len(s); i++ { - s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0} - } - return s, nil -} - -func (d *Database) StoreNewSendForDeviceMessage( - ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, -) (newPos types.StreamPosition, err error) { - j, err := json.Marshal(event) - if err != nil { - return 0, err - } - // Delegate the database write task to the SendToDeviceWriter. It'll guarantee - // that we don't lock the table for writes in more than one place. - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - newPos, err = d.SendToDevice.InsertSendToDeviceMessage( - ctx, txn, userID, deviceID, string(j), - ) - return err - }) - if err != nil { - return 0, err - } - return newPos, nil -} - -func (d *Database) SendToDeviceUpdatesForSync( - ctx context.Context, - userID, deviceID string, - from, to types.StreamPosition, -) (types.StreamPosition, []types.SendToDeviceEvent, error) { - // First of all, get our send-to-device updates for this user. - lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to) - if err != nil { - return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) - } - // If there's nothing to do then stop here. - if len(events) == 0 { - return to, nil, nil - } - return lastPos, events, nil -} - -func (d *Database) CleanSendToDeviceUpdates( - ctx context.Context, - userID, deviceID string, before types.StreamPosition, -) (err error) { - if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before) - }); err != nil { - logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID) - return err - } - return nil -} - -// getMembershipFromEvent returns the value of content.membership iff the event is a state event -// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned. -func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) { - if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) { - return "", "" - } - membership, err := ev.Membership() - if err != nil { - return "", "" - } - prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str - return membership, prevMembership -} - -// StoreReceipt stores user receipts -func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp) - return err - }) - return -} - -func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) { - _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos) - return receipts, err -} - -func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, txn, userID, roomID, notificationCount, highlightCount) - return err - }) - return -} - -func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) { - roomIDs := make([]string, 0, len(rooms)) - for roomID, membership := range rooms { - if membership != gomatrixserverlib.Join { - continue - } - roomIDs = append(roomIDs, roomID) - } - return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, nil, userID, roomIDs) -} - -func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { - return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) -} - -func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) { - return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter) -} -func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) { - return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter) -} - -func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) { - return d.Ignores.SelectIgnores(ctx, nil, userID) -} - -func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Ignores.UpsertIgnores(ctx, txn, userID, ignores) - }) -} - -func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) { - var pos types.StreamPosition - var err error - _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - pos, err = d.Presence.UpsertPresence(ctx, txn, userID, statusMsg, presence, lastActiveTS, fromSync) - return nil - }) - return pos, err -} - -func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, nil, userID) -} - -func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { - return d.Presence.GetPresenceAfter(ctx, nil, after, filter) -} - -func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { - return d.Presence.GetMaxPresenceID(ctx, nil) -} - -func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { - return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) -} - -func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) { - return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{ - gomatrixserverlib.MRoomName, - gomatrixserverlib.MRoomTopic, - "m.room.message", - }) -} diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index ba6d8126..c4019fed 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -367,7 +367,13 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs( for start < len(eventIDs) { n := minOfInts(len(eventIDs)-start, 999) query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1) - rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) + var rows *sql.Rows + var err error + if txn == nil { + rows, err = s.db.QueryContext(ctx, query, iEventIDs[start:start+n]...) + } else { + rows, err = txn.QueryContext(ctx, query, iEventIDs[start:start+n]...) + } if err != nil { return nil, err } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 58ab8461..e2dbcd5c 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -50,7 +50,7 @@ const deleteInviteEventSQL = "" + "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false" const selectInviteEventsInRangeSQL = "" + - "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + + "SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" + " WHERE target_user_id = $1 AND id > $2 AND id <= $3" + " ORDER BY id DESC" @@ -132,23 +132,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent( // active invites for the target user ID in the supplied range. func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range, -) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) { +) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) { + var lastPos types.StreamPosition stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High()) if err != nil { - return nil, nil, err + return nil, nil, lastPos, err } defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed") result := map[string]*gomatrixserverlib.HeaderedEvent{} retired := map[string]*gomatrixserverlib.HeaderedEvent{} for rows.Next() { var ( + id types.StreamPosition roomID string eventJSON []byte deleted bool ) - if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil { - return nil, nil, err + if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil { + return nil, nil, lastPos, err + } + if id > lastPos { + lastPos = id } // if we have seen this room before, it has a higher stream position and hence takes priority @@ -161,15 +166,19 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange( var event *gomatrixserverlib.HeaderedEvent if err := json.Unmarshal(eventJSON, &event); err != nil { - return nil, nil, err + return nil, nil, lastPos, err } + if deleted { retired[roomID] = event } else { result[roomID] = event } } - return result, retired, nil + if lastPos == 0 { + lastPos = r.To + } + return result, retired, lastPos, nil } func (s *inviteEventsStatements) SelectMaxInviteID( diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index a84e2bd1..0879030a 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -49,6 +49,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) return &d, nil } +func (d *SyncServerDatasource) NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error) { + return &shared.DatabaseTransaction{ + Database: &d.Database, + // not setting a transaction because SQLite doesn't support it + }, nil +} + +func (d *SyncServerDatasource) NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error) { + return &shared.DatabaseTransaction{ + Database: &d.Database, + // not setting a transaction because SQLite doesn't support it + }, nil +} + func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) { if err = d.streamID.Prepare(d.db); err != nil { return err diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index a62818e9..5ff185a3 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -60,6 +60,17 @@ func TestWriteEvents(t *testing.T) { }) } +func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.DatabaseTransaction)) { + snapshot, err := db.NewDatabaseSnapshot(ctx) + if err != nil { + t.Fatal(err) + } + f(snapshot) + if err := snapshot.Rollback(); err != nil { + t.Fatal(err) + } +} + // These tests assert basic functionality of RecentEvents for PDUs func TestRecentEventsPDU(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -79,10 +90,13 @@ func TestRecentEventsPDU(t *testing.T) { // dummy room to make sure SQL queries are filtering on room ID MustWriteEvents(t, db, test.NewRoom(t, alice).Events()) - latest, err := db.MaxStreamPositionForPDUs(ctx) - if err != nil { - t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err) - } + var latest types.StreamPosition + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + var err error + if latest, err = snapshot.MaxStreamPositionForPDUs(ctx); err != nil { + t.Fatal("failed to get MaxStreamPositionForPDUs: %w", err) + } + }) testCases := []struct { Name string @@ -140,14 +154,19 @@ func TestRecentEventsPDU(t *testing.T) { tc := testCases[i] t.Run(tc.Name, func(st *testing.T) { var filter gomatrixserverlib.RoomEventFilter + var gotEvents []types.StreamEvent + var limited bool filter.Limit = tc.Limit - gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{ - From: tc.From, - To: tc.To, - }, &filter, !tc.ReverseOrder, true) - if err != nil { - st.Fatalf("failed to do sync: %s", err) - } + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + var err error + gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{ + From: tc.From, + To: tc.To, + }, &filter, !tc.ReverseOrder, true) + if err != nil { + st.Fatalf("failed to do sync: %s", err) + } + }) if limited != tc.WantLimited { st.Errorf("got limited=%v want %v", limited, tc.WantLimited) } @@ -178,22 +197,24 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { events := r.Events() _ = MustWriteEvents(t, db, events) - from, err := db.MaxTopologicalPosition(ctx, r.ID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } - t.Logf("max topo pos = %+v", from) - // head towards the beginning of time - to := types.TopologyToken{} + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + from, err := snapshot.MaxTopologicalPosition(ctx, r.ID) + if err != nil { + t.Fatalf("failed to get MaxTopologicalPosition: %s", err) + } + t.Logf("max topo pos = %+v", from) + // head towards the beginning of time + to := types.TopologyToken{} - // backpaginate 5 messages starting at the latest position. - filter := &gomatrixserverlib.RoomEventFilter{Limit: 5} - paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) - if err != nil { - t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) - } - gots := db.StreamEventsToEvents(nil, paginatedEvents) - test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) + // backpaginate 5 messages starting at the latest position. + filter := &gomatrixserverlib.RoomEventFilter{Limit: 5} + paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) + if err != nil { + t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) + } + gots := snapshot.StreamEventsToEvents(nil, paginatedEvents) + test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) + }) }) } @@ -414,13 +435,16 @@ func TestSendToDeviceBehaviour(t *testing.T) { defer closeBase() // At this point there should be no messages. We haven't sent anything // yet. - _, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) - if err != nil { - t.Fatal(err) - } - if len(events) != 0 { - t.Fatal("first call should have no updates") - } + + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 { + t.Fatal("first call should have no updates") + } + }) // Try sending a message. streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{ @@ -432,51 +456,58 @@ func TestSendToDeviceBehaviour(t *testing.T) { t.Fatal(err) } - // At this point we should get exactly one message. We're sending the sync position - // that we were given from the update and the send-to-device update will be updated - // in the database to reflect that this was the sync position we sent the message at. - streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) - if err != nil { - t.Fatal(err) - } - if count := len(events); count != 1 { - t.Fatalf("second call should have one update, got %d", count) - } + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + // At this point we should get exactly one message. We're sending the sync position + // that we were given from the update and the send-to-device update will be updated + // in the database to reflect that this was the sync position we sent the message at. + var events []types.SendToDeviceEvent + streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) + if err != nil { + t.Fatal(err) + } + if count := len(events); count != 1 { + t.Fatalf("second call should have one update, got %d", count) + } + + // At this point we should still have one message because we haven't progressed the + // sync position yet. This is equivalent to the client failing to /sync and retrying + // with the same position. + streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) + if err != nil { + t.Fatal(err) + } + if len(events) != 1 { + t.Fatal("third call should have one update still") + } + }) - // At this point we should still have one message because we haven't progressed the - // sync position yet. This is equivalent to the client failing to /sync and retrying - // with the same position. - streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos) - if err != nil { - t.Fatal(err) - } - if len(events) != 1 { - t.Fatal("third call should have one update still") - } err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos) if err != nil { return } - // At this point we should now have no updates, because we've progressed the sync - // position. Therefore the update from before will not be sent again. - _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) - if err != nil { - t.Fatal(err) - } - if len(events) != 0 { - t.Fatal("fourth call should have no updates") - } + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + // At this point we should now have no updates, because we've progressed the sync + // position. Therefore the update from before will not be sent again. + var events []types.SendToDeviceEvent + _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 { + t.Fatal("fourth call should have no updates") + } - // At this point we should still have no updates, because no new updates have been - // sent. - _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) - if err != nil { - t.Fatal(err) - } - if len(events) != 0 { - t.Fatal("fifth call should have no updates") - } + // At this point we should still have no updates, because no new updates have been + // sent. + _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10) + if err != nil { + t.Fatal(err) + } + if len(events) != 0 { + t.Fatal("fifth call should have no updates") + } + }) // Send some more messages and verify the ordering is correct ("in order of arrival") var lastPos types.StreamPosition = 0 @@ -492,18 +523,20 @@ func TestSendToDeviceBehaviour(t *testing.T) { lastPos = streamPos } - _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos) - if err != nil { - t.Fatalf("unable to get events: %v", err) - } + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos) + if err != nil { + t.Fatalf("unable to get events: %v", err) + } - for i := 0; i < 10; i++ { - want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i)) - got := events[i].Content - if !bytes.Equal(got, want) { - t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got)) + for i := 0; i < 10; i++ { + want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i)) + got := events[i].Content + if !bytes.Equal(got, want) { + t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got)) + } } - } + }) }) } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 89cb537a..2fdc3cfb 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -37,7 +37,7 @@ type Invites interface { DeleteInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string) (types.StreamPosition, error) // SelectInviteEventsInRange returns a map of room ID to invite events. If multiple invite/retired invites exist in the given range, return the latest value // for the room. - SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, err error) + SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) } diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 0297d5c2..3f2f7d13 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -5,22 +5,25 @@ import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" ) type AccountDataStreamProvider struct { - StreamProvider + DefaultStreamProvider userAPI userapi.SyncUserAPI } -func (p *AccountDataStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *AccountDataStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) p.latestMutex.Lock() defer p.latestMutex.Unlock() - id, err := p.DB.MaxStreamPositionForAccountData(context.Background()) + id, err := snapshot.MaxStreamPositionForAccountData(ctx) if err != nil { panic(err) } @@ -29,13 +32,15 @@ func (p *AccountDataStreamProvider) Setup() { func (p *AccountDataStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *AccountDataStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { @@ -44,7 +49,7 @@ func (p *AccountDataStreamProvider) IncrementalSync( To: to, } - dataTypes, pos, err := p.DB.GetAccountDataInRange( + dataTypes, pos, err := snapshot.GetAccountDataInRange( ctx, req.Device.UserID, r, &req.Filter.AccountData, ) if err != nil { diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 5448ee5b..7996c203 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -6,17 +6,19 @@ import ( keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/internal" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type DeviceListStreamProvider struct { - StreamProvider + DefaultStreamProvider rsAPI api.SyncRoomserverAPI keyAPI keyapi.SyncKeyAPI } func (p *DeviceListStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { return p.LatestPosition(ctx) @@ -24,11 +26,12 @@ func (p *DeviceListStreamProvider) CompleteSync( func (p *DeviceListStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { var err error - to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) + to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) if err != nil { req.Log.WithError(err).Error("internal.DeviceListCatchup failed") return from diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index 925da32f..17b3b843 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -9,20 +9,23 @@ import ( "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type InviteStreamProvider struct { - StreamProvider + DefaultStreamProvider } -func (p *InviteStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *InviteStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) p.latestMutex.Lock() defer p.latestMutex.Unlock() - id, err := p.DB.MaxStreamPositionForInvites(context.Background()) + id, err := snapshot.MaxStreamPositionForInvites(ctx) if err != nil { panic(err) } @@ -31,13 +34,15 @@ func (p *InviteStreamProvider) Setup() { func (p *InviteStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *InviteStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { @@ -46,7 +51,7 @@ func (p *InviteStreamProvider) IncrementalSync( To: to, } - invites, retiredInvites, err := p.DB.InviteEventsInRange( + invites, retiredInvites, maxID, err := snapshot.InviteEventsInRange( ctx, req.Device.UserID, r, ) if err != nil { @@ -86,5 +91,5 @@ func (p *InviteStreamProvider) IncrementalSync( } } - return to + return maxID } diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go index 33872734..5a81fd09 100644 --- a/syncapi/streams/stream_notificationdata.go +++ b/syncapi/streams/stream_notificationdata.go @@ -3,17 +3,23 @@ package streams import ( "context" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type NotificationDataStreamProvider struct { - StreamProvider + DefaultStreamProvider } -func (p *NotificationDataStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *NotificationDataStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) - id, err := p.DB.MaxStreamPositionForNotificationData(context.Background()) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + id, err := snapshot.MaxStreamPositionForNotificationData(ctx) if err != nil { panic(err) } @@ -22,20 +28,22 @@ func (p *NotificationDataStreamProvider) Setup() { func (p *NotificationDataStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *NotificationDataStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, _ types.StreamPosition, ) types.StreamPosition { // Get the unread notifications for rooms in our join response. // This is to ensure clients always have an unread notification section // and can display the correct numbers. - countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms) + countsByRoom, err := snapshot.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms) if err != nil { req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed") return from diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 0ab6de88..89c5ba35 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -5,7 +5,6 @@ import ( "database/sql" "fmt" "sort" - "sync" "time" "github.com/matrix-org/dendrite/internal/caching" @@ -18,7 +17,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - "go.uber.org/atomic" "github.com/matrix-org/dendrite/syncapi/notifier" ) @@ -33,44 +31,23 @@ const PDU_STREAM_WORKERS = 256 const PDU_STREAM_QUEUESIZE = PDU_STREAM_WORKERS * 8 type PDUStreamProvider struct { - StreamProvider + DefaultStreamProvider - tasks chan func() - workers atomic.Int32 // userID+deviceID -> lazy loading cache lazyLoadCache caching.LazyLoadCache rsAPI roomserverAPI.SyncRoomserverAPI notifier *notifier.Notifier } -func (p *PDUStreamProvider) worker() { - defer p.workers.Dec() - for { - select { - case f := <-p.tasks: - f() - case <-time.After(time.Second * 10): - return - } - } -} - -func (p *PDUStreamProvider) queue(f func()) { - if p.workers.Load() < PDU_STREAM_WORKERS { - p.workers.Inc() - go p.worker() - } - p.tasks <- f -} - -func (p *PDUStreamProvider) Setup() { - p.StreamProvider.Setup() - p.tasks = make(chan func(), PDU_STREAM_QUEUESIZE) +func (p *PDUStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) p.latestMutex.Lock() defer p.latestMutex.Unlock() - id, err := p.DB.MaxStreamPositionForPDUs(context.Background()) + id, err := snapshot.MaxStreamPositionForPDUs(ctx) if err != nil { panic(err) } @@ -79,6 +56,7 @@ func (p *PDUStreamProvider) Setup() { func (p *PDUStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { from := types.StreamPosition(0) @@ -94,7 +72,7 @@ func (p *PDUStreamProvider) CompleteSync( } // Extract room state and recent events for all rooms the user is joined to. - joinedRoomIDs, err := p.DB.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join) + joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join) if err != nil { req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed") return from @@ -103,7 +81,7 @@ func (p *PDUStreamProvider) CompleteSync( stateFilter := req.Filter.Room.State eventFilter := req.Filter.Room.Timeline - if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil { + if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil { req.Log.WithError(err).Error("unable to update event filter with ignored users") } @@ -117,33 +95,20 @@ func (p *PDUStreamProvider) CompleteSync( } // Build up a /sync response. Add joined rooms. - var reqMutex sync.Mutex - var reqWaitGroup sync.WaitGroup - reqWaitGroup.Add(len(joinedRoomIDs)) - for _, room := range joinedRoomIDs { - roomID := room - p.queue(func() { - defer reqWaitGroup.Done() - - jr, jerr := p.getJoinResponseForCompleteSync( - ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, - ) - if jerr != nil { - req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") - return - } - - reqMutex.Lock() - defer reqMutex.Unlock() - req.Response.Rooms.Join[roomID] = *jr - req.Rooms[roomID] = gomatrixserverlib.Join - }) + for _, roomID := range joinedRoomIDs { + jr, jerr := p.getJoinResponseForCompleteSync( + ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, + ) + if jerr != nil { + req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") + continue // return from + } + req.Response.Rooms.Join[roomID] = *jr + req.Rooms[roomID] = gomatrixserverlib.Join } - reqWaitGroup.Wait() - // Add peeked rooms. - peeks, err := p.DB.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r) + peeks, err := snapshot.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r) if err != nil { req.Log.WithError(err).Error("p.DB.PeeksInRange failed") return from @@ -152,11 +117,11 @@ func (p *PDUStreamProvider) CompleteSync( if !peek.Deleted { var jr *types.JoinResponse jr, err = p.getJoinResponseForCompleteSync( - ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, + ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, ) if err != nil { req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") - return from + continue // return from } req.Response.Rooms.Peek[peek.RoomID] = *jr } @@ -167,6 +132,7 @@ func (p *PDUStreamProvider) CompleteSync( func (p *PDUStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) (newPos types.StreamPosition) { @@ -184,12 +150,12 @@ func (p *PDUStreamProvider) IncrementalSync( eventFilter := req.Filter.Room.Timeline if req.WantFullState { - if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") return } } else { - if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") return } @@ -203,7 +169,7 @@ func (p *PDUStreamProvider) IncrementalSync( return to } - if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil { + if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil { req.Log.WithError(err).Error("unable to update event filter with ignored users") } @@ -222,7 +188,7 @@ func (p *PDUStreamProvider) IncrementalSync( } } var pos types.StreamPosition - if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { + if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil { req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed") return to } @@ -244,6 +210,7 @@ func (p *PDUStreamProvider) IncrementalSync( // nolint:gocyclo func (p *PDUStreamProvider) addRoomDeltaToResponse( ctx context.Context, + snapshot storage.DatabaseTransaction, device *userapi.Device, r types.Range, delta types.StateDelta, @@ -260,7 +227,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // This is all "okay" assuming history_visibility == "shared" which it is by default. r.To = delta.MembershipPos } - recentStreamEvents, limited, err := p.DB.RecentEvents( + recentStreamEvents, limited, err := snapshot.RecentEvents( ctx, delta.RoomID, r, eventFilter, true, true, ) @@ -270,9 +237,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) } - recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) + recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back - prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents) + prevBatch, err := snapshot.GetBackwardTopologyPos(ctx, recentStreamEvents) if err != nil { return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err) } @@ -291,7 +258,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( latestPosition := r.To updateLatestPosition := func(mostRecentEventID string) { var pos types.StreamPosition - if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil { + if _, pos, err = snapshot.PositionInTopology(ctx, mostRecentEventID); err == nil { switch { case r.Backwards && pos < latestPosition: fallthrough @@ -303,7 +270,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( if stateFilter.LazyLoadMembers { delta.StateEvents, err = p.lazyLoadMembers( - ctx, delta.RoomID, true, limited, stateFilter, + ctx, snapshot, delta.RoomID, true, limited, stateFilter, device, recentEvents, delta.StateEvents, ) if err != nil && err != sql.ErrNoRows { @@ -320,7 +287,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } // Applies the history visibility rules - events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents) + events, err := applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") } @@ -336,7 +303,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case gomatrixserverlib.Join: jr := types.NewJoinResponse() if hasMembershipChange { - p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition) + p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition) } jr.Timeline.PrevBatch = &prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) @@ -376,7 +343,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( // sure we always return the required events in the timeline. func applyHistoryVisibilityFilter( ctx context.Context, - db storage.Database, + snapshot storage.DatabaseTransaction, rsAPI roomserverAPI.SyncRoomserverAPI, roomID, userID string, limit int, @@ -384,7 +351,7 @@ func applyHistoryVisibilityFilter( ) ([]*gomatrixserverlib.HeaderedEvent, error) { // We need to make sure we always include the latest states events, if they are in the timeline. // We grep at least limit * 2 events, to ensure we really get the needed events. - stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil) + stateEvents, err := snapshot.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil) if err != nil { // Not a fatal error, we can continue without the stateEvents, // they are only needed if there are state events in the timeline. @@ -395,7 +362,7 @@ func applyHistoryVisibilityFilter( alwaysIncludeIDs[ev.EventID()] = struct{}{} } startTime := time.Now() - events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") + events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") if err != nil { return nil, err } @@ -408,10 +375,10 @@ func applyHistoryVisibilityFilter( return events, nil } -func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { +func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { // Work out how many members are in the room. - joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) - invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) + joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) + invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) jr.Summary.JoinedMemberCount = &joinedCount jr.Summary.InvitedMemberCount = &invitedCount @@ -439,7 +406,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe } } } - heroes, err := p.DB.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) + heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) if err != nil { return } @@ -449,6 +416,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, roomID string, r types.Range, stateFilter *gomatrixserverlib.StateFilter, @@ -460,7 +428,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - recentStreamEvents, limited, err := p.DB.RecentEvents( + recentStreamEvents, limited, err := snapshot.RecentEvents( ctx, roomID, r, eventFilter, true, true, ) if err != nil { @@ -484,7 +452,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } } - stateEvents, err := p.DB.CurrentState(ctx, roomID, stateFilter, excludingEventIDs) + stateEvents, err := snapshot.CurrentState(ctx, roomID, stateFilter, excludingEventIDs) if err != nil { return } @@ -494,7 +462,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( var prevBatch *types.TopologyToken if len(recentStreamEvents) > 0 { var backwardTopologyPos, backwardStreamPos types.StreamPosition - backwardTopologyPos, backwardStreamPos, err = p.DB.PositionInTopology(ctx, recentStreamEvents[0].EventID()) + backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, recentStreamEvents[0].EventID()) if err != nil { return } @@ -505,18 +473,18 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( prevBatch.Decrement() } - p.addRoomSummary(ctx, jr, roomID, device.UserID, r.From) + p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From) // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // "Can sync a room with a message with a transaction id" - which does a complete sync to check. - recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents) + recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) events := recentEvents // Only apply history visibility checks if the response is for joined rooms if !isPeek { - events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents) + events, err = applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents) if err != nil { logrus.WithError(err).Error("unable to apply history visibility filter") } @@ -530,7 +498,8 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( if err != nil { return nil, err } - stateEvents, err = p.lazyLoadMembers(ctx, roomID, + stateEvents, err = p.lazyLoadMembers( + ctx, snapshot, roomID, false, limited, stateFilter, device, recentEvents, stateEvents, ) @@ -549,7 +518,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } func (p *PDUStreamProvider) lazyLoadMembers( - ctx context.Context, roomID string, + ctx context.Context, snapshot storage.DatabaseTransaction, roomID string, incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter, device *userapi.Device, timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent, @@ -598,7 +567,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( filter.Limit = stateFilter.Limit filter.Senders = &wantUsers filter.Types = &[]string{gomatrixserverlib.MRoomMember} - memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter) + memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) if err != nil { return stateEvents, err } @@ -612,8 +581,8 @@ func (p *PDUStreamProvider) lazyLoadMembers( // addIgnoredUsersToFilter adds ignored users to the eventfilter and // the syncreq itself for further use in streams. -func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { - ignores, err := p.DB.IgnoresForUser(ctx, req.Device.UserID) +func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error { + ignores, err := snapshot.IgnoresForUser(ctx, req.Device.UserID) if err != nil { if err == sql.ErrNoRows { return nil diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 15db4d30..81cea7d5 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -23,20 +23,26 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/syncapi/notifier" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type PresenceStreamProvider struct { - StreamProvider + DefaultStreamProvider // cache contains previously sent presence updates to avoid unneeded updates cache sync.Map notifier *notifier.Notifier } -func (p *PresenceStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *PresenceStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) - id, err := p.DB.MaxStreamPositionForPresence(context.Background()) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + id, err := snapshot.MaxStreamPositionForPresence(ctx) if err != nil { panic(err) } @@ -45,18 +51,20 @@ func (p *PresenceStreamProvider) Setup() { func (p *PresenceStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *PresenceStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { // We pull out a larger number than the filter asks for, since we're filtering out events later - presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) + presences, err := snapshot.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000}) if err != nil { req.Log.WithError(err).Error("p.DB.PresenceAfter failed") return from @@ -84,7 +92,7 @@ func (p *PresenceStreamProvider) IncrementalSync( } // Bear in mind that this might return nil, but at least populating // a nil means that there's a map entry so we won't repeat this call. - presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i]) + presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i]) if err != nil { req.Log.WithError(err).Error("unable to query presence for user") return from diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index f4e84c7d..8818a553 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -4,18 +4,24 @@ import ( "context" "encoding/json" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) type ReceiptStreamProvider struct { - StreamProvider + DefaultStreamProvider } -func (p *ReceiptStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *ReceiptStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) - id, err := p.DB.MaxStreamPositionForReceipts(context.Background()) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + id, err := snapshot.MaxStreamPositionForReceipts(ctx) if err != nil { panic(err) } @@ -24,13 +30,15 @@ func (p *ReceiptStreamProvider) Setup() { func (p *ReceiptStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *ReceiptStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { @@ -41,7 +49,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( } } - lastPos, receipts, err := p.DB.RoomReceiptsAfter(ctx, joinedRooms, from) + lastPos, receipts, err := snapshot.RoomReceiptsAfter(ctx, joinedRooms, from) if err != nil { req.Log.WithError(err).Error("p.DB.RoomReceiptsAfter failed") return from diff --git a/syncapi/streams/stream_sendtodevice.go b/syncapi/streams/stream_sendtodevice.go index 31c6187c..00b67cc4 100644 --- a/syncapi/streams/stream_sendtodevice.go +++ b/syncapi/streams/stream_sendtodevice.go @@ -3,17 +3,23 @@ package streams import ( "context" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type SendToDeviceStreamProvider struct { - StreamProvider + DefaultStreamProvider } -func (p *SendToDeviceStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *SendToDeviceStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) - id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background()) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + id, err := snapshot.MaxStreamPositionForSendToDeviceMessages(ctx) if err != nil { panic(err) } @@ -22,18 +28,20 @@ func (p *SendToDeviceStreamProvider) Setup() { func (p *SendToDeviceStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *SendToDeviceStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { // See if we have any new tasks to do for the send-to-device messaging. - lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to) + lastPos, events, err := snapshot.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to) if err != nil { req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed") return from diff --git a/syncapi/streams/stream_typing.go b/syncapi/streams/stream_typing.go index f781065b..a6f7c7a0 100644 --- a/syncapi/streams/stream_typing.go +++ b/syncapi/streams/stream_typing.go @@ -5,24 +5,27 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) type TypingStreamProvider struct { - StreamProvider + DefaultStreamProvider EDUCache *caching.EDUCache } func (p *TypingStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *TypingStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition, ) types.StreamPosition { diff --git a/syncapi/streams/streamprovider.go b/syncapi/streams/streamprovider.go new file mode 100644 index 00000000..8b12e2eb --- /dev/null +++ b/syncapi/streams/streamprovider.go @@ -0,0 +1,28 @@ +package streams + +import ( + "context" + + "github.com/matrix-org/dendrite/syncapi/storage" + "github.com/matrix-org/dendrite/syncapi/types" +) + +type StreamProvider interface { + Setup(ctx context.Context, snapshot storage.DatabaseTransaction) + + // Advance will update the latest position of the stream based on + // an update and will wake callers waiting on StreamNotifyAfter. + Advance(latest types.StreamPosition) + + // CompleteSync will update the response to include all updates as needed + // for a complete sync. It will always return immediately. + CompleteSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest) types.StreamPosition + + // IncrementalSync will update the response to include all updates between + // the from and to sync positions. It will always return immediately, + // making no changes if the range contains no updates. + IncrementalSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition) types.StreamPosition + + // LatestPosition returns the latest stream position for this stream. + LatestPosition(ctx context.Context) types.StreamPosition +} diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index dbc053bd..eccbb3a4 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -13,15 +13,15 @@ import ( ) type Streams struct { - PDUStreamProvider types.StreamProvider - TypingStreamProvider types.StreamProvider - ReceiptStreamProvider types.StreamProvider - InviteStreamProvider types.StreamProvider - SendToDeviceStreamProvider types.StreamProvider - AccountDataStreamProvider types.StreamProvider - DeviceListStreamProvider types.StreamProvider - NotificationDataStreamProvider types.StreamProvider - PresenceStreamProvider types.StreamProvider + PDUStreamProvider StreamProvider + TypingStreamProvider StreamProvider + ReceiptStreamProvider StreamProvider + InviteStreamProvider StreamProvider + SendToDeviceStreamProvider StreamProvider + AccountDataStreamProvider StreamProvider + DeviceListStreamProvider StreamProvider + NotificationDataStreamProvider StreamProvider + PresenceStreamProvider StreamProvider } func NewSyncStreamProviders( @@ -31,51 +31,58 @@ func NewSyncStreamProviders( ) *Streams { streams := &Streams{ PDUStreamProvider: &PDUStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - lazyLoadCache: lazyLoadCache, - rsAPI: rsAPI, - notifier: notifier, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + lazyLoadCache: lazyLoadCache, + rsAPI: rsAPI, + notifier: notifier, }, TypingStreamProvider: &TypingStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - EDUCache: eduCache, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + EDUCache: eduCache, }, ReceiptStreamProvider: &ReceiptStreamProvider{ - StreamProvider: StreamProvider{DB: d}, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, }, InviteStreamProvider: &InviteStreamProvider{ - StreamProvider: StreamProvider{DB: d}, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, }, SendToDeviceStreamProvider: &SendToDeviceStreamProvider{ - StreamProvider: StreamProvider{DB: d}, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, }, AccountDataStreamProvider: &AccountDataStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - userAPI: userAPI, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + userAPI: userAPI, }, NotificationDataStreamProvider: &NotificationDataStreamProvider{ - StreamProvider: StreamProvider{DB: d}, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, }, DeviceListStreamProvider: &DeviceListStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - rsAPI: rsAPI, - keyAPI: keyAPI, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + rsAPI: rsAPI, + keyAPI: keyAPI, }, PresenceStreamProvider: &PresenceStreamProvider{ - StreamProvider: StreamProvider{DB: d}, - notifier: notifier, + DefaultStreamProvider: DefaultStreamProvider{DB: d}, + notifier: notifier, }, } - streams.PDUStreamProvider.Setup() - streams.TypingStreamProvider.Setup() - streams.ReceiptStreamProvider.Setup() - streams.InviteStreamProvider.Setup() - streams.SendToDeviceStreamProvider.Setup() - streams.AccountDataStreamProvider.Setup() - streams.NotificationDataStreamProvider.Setup() - streams.DeviceListStreamProvider.Setup() - streams.PresenceStreamProvider.Setup() + ctx := context.TODO() + snapshot, err := d.NewDatabaseSnapshot(ctx) + if err != nil { + panic(err) + } + defer snapshot.Rollback() // nolint:errcheck + + streams.PDUStreamProvider.Setup(ctx, snapshot) + streams.TypingStreamProvider.Setup(ctx, snapshot) + streams.ReceiptStreamProvider.Setup(ctx, snapshot) + streams.InviteStreamProvider.Setup(ctx, snapshot) + streams.SendToDeviceStreamProvider.Setup(ctx, snapshot) + streams.AccountDataStreamProvider.Setup(ctx, snapshot) + streams.NotificationDataStreamProvider.Setup(ctx, snapshot) + streams.DeviceListStreamProvider.Setup(ctx, snapshot) + streams.PresenceStreamProvider.Setup(ctx, snapshot) return streams } diff --git a/syncapi/streams/template_stream.go b/syncapi/streams/template_stream.go index 15074cc1..f208d84e 100644 --- a/syncapi/streams/template_stream.go +++ b/syncapi/streams/template_stream.go @@ -8,16 +8,18 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" ) -type StreamProvider struct { +type DefaultStreamProvider struct { DB storage.Database latest types.StreamPosition latestMutex sync.RWMutex } -func (p *StreamProvider) Setup() { +func (p *DefaultStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { } -func (p *StreamProvider) Advance( +func (p *DefaultStreamProvider) Advance( latest types.StreamPosition, ) { p.latestMutex.Lock() @@ -28,7 +30,7 @@ func (p *StreamProvider) Advance( } } -func (p *StreamProvider) LatestPosition( +func (p *DefaultStreamProvider) LatestPosition( ctx context.Context, ) types.StreamPosition { p.latestMutex.RLock() diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index b2ea105f..1d0ac1a4 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -305,6 +305,13 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately") } + snapshot, err := rp.db.NewDatabaseSnapshot(req.Context()) + if err != nil { + logrus.WithError(err).Error("Failed to acquire database snapshot for sync request") + return jsonerror.InternalServerError() + } + defer snapshot.Rollback() // nolint:errcheck + if syncReq.Since.IsEmpty() { // Complete sync syncReq.Response.NextBatch = types.StreamingToken{ @@ -312,70 +319,70 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. // might advance while processing other streams, resulting in flakey // tests. DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, ), PDUPosition: rp.streams.PDUStreamProvider.CompleteSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, ), TypingPosition: rp.streams.TypingStreamProvider.CompleteSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, ), ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, ), InvitePosition: rp.streams.InviteStreamProvider.CompleteSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, ), SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, ), AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, ), NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, ), PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, ), } } else { // Incremental sync syncReq.Response.NextBatch = types.StreamingToken{ PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, syncReq.Since.PDUPosition, currentPos.PDUPosition, ), TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, syncReq.Since.TypingPosition, currentPos.TypingPosition, ), ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition, ), InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, syncReq.Since.InvitePosition, currentPos.InvitePosition, ), SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition, ), AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, ), NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition, ), DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, ), PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync( - syncReq.Context, syncReq, + syncReq.Context, snapshot, syncReq, syncReq.Since.PresencePosition, currentPos.PresencePosition, ), } @@ -437,9 +444,15 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed") return jsonerror.InternalServerError() } - rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition) + snapshot, err := rp.db.NewDatabaseSnapshot(req.Context()) + if err != nil { + logrus.WithError(err).Error("Failed to acquire database snapshot for key change") + return jsonerror.InternalServerError() + } + defer snapshot.Rollback() // nolint:errcheck + rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition) _, _, err = internal.DeviceListCatchup( - req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, + req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID, syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition, ) if err != nil { diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index a9ea234d..378cafe9 100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -41,23 +41,3 @@ func (r *SyncRequest) IsRoomPresent(roomID string) bool { return false } } - -type StreamProvider interface { - Setup() - - // Advance will update the latest position of the stream based on - // an update and will wake callers waiting on StreamNotifyAfter. - Advance(latest StreamPosition) - - // CompleteSync will update the response to include all updates as needed - // for a complete sync. It will always return immediately. - CompleteSync(ctx context.Context, req *SyncRequest) StreamPosition - - // IncrementalSync will update the response to include all updates between - // the from and to sync positions. It will always return immediately, - // making no changes if the range contains no updates. - IncrementalSync(ctx context.Context, req *SyncRequest, from, to StreamPosition) StreamPosition - - // LatestPosition returns the latest stream position for this stream. - LatestPosition(ctx context.Context) StreamPosition -} |