aboutsummaryrefslogtreecommitdiff
path: root/syncapi/streams/stream_notificationdata.go
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi/streams/stream_notificationdata.go')
-rw-r--r--syncapi/streams/stream_notificationdata.go20
1 files changed, 14 insertions, 6 deletions
diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go
index 33872734..5a81fd09 100644
--- a/syncapi/streams/stream_notificationdata.go
+++ b/syncapi/streams/stream_notificationdata.go
@@ -3,17 +3,23 @@ package streams
import (
"context"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type NotificationDataStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
}
-func (p *NotificationDataStreamProvider) Setup() {
- p.StreamProvider.Setup()
+func (p *NotificationDataStreamProvider) Setup(
+ ctx context.Context, snapshot storage.DatabaseTransaction,
+) {
+ p.DefaultStreamProvider.Setup(ctx, snapshot)
- id, err := p.DB.MaxStreamPositionForNotificationData(context.Background())
+ p.latestMutex.Lock()
+ defer p.latestMutex.Unlock()
+
+ id, err := snapshot.MaxStreamPositionForNotificationData(ctx)
if err != nil {
panic(err)
}
@@ -22,20 +28,22 @@ func (p *NotificationDataStreamProvider) Setup() {
func (p *NotificationDataStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
- return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
+ return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, _ types.StreamPosition,
) types.StreamPosition {
// Get the unread notifications for rooms in our join response.
// This is to ensure clients always have an unread notification section
// and can display the correct numbers.
- countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
+ countsByRoom, err := snapshot.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
return from