diff options
Diffstat (limited to 'syncapi/streams/stream_notificationdata.go')
-rw-r--r-- | syncapi/streams/stream_notificationdata.go | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go index 33872734..5a81fd09 100644 --- a/syncapi/streams/stream_notificationdata.go +++ b/syncapi/streams/stream_notificationdata.go @@ -3,17 +3,23 @@ package streams import ( "context" + "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) type NotificationDataStreamProvider struct { - StreamProvider + DefaultStreamProvider } -func (p *NotificationDataStreamProvider) Setup() { - p.StreamProvider.Setup() +func (p *NotificationDataStreamProvider) Setup( + ctx context.Context, snapshot storage.DatabaseTransaction, +) { + p.DefaultStreamProvider.Setup(ctx, snapshot) - id, err := p.DB.MaxStreamPositionForNotificationData(context.Background()) + p.latestMutex.Lock() + defer p.latestMutex.Unlock() + + id, err := snapshot.MaxStreamPositionForNotificationData(ctx) if err != nil { panic(err) } @@ -22,20 +28,22 @@ func (p *NotificationDataStreamProvider) Setup() { func (p *NotificationDataStreamProvider) CompleteSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, ) types.StreamPosition { - return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx)) + return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx)) } func (p *NotificationDataStreamProvider) IncrementalSync( ctx context.Context, + snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, _ types.StreamPosition, ) types.StreamPosition { // Get the unread notifications for rooms in our join response. // This is to ensure clients always have an unread notification section // and can display the correct numbers. - countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms) + countsByRoom, err := snapshot.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms) if err != nil { req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed") return from |