aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-09-30 12:48:10 +0100
committerGitHub <noreply@github.com>2022-09-30 12:48:10 +0100
commit6348486a1365c7469a498101f5035a9b6bd16d22 (patch)
treed8a5ba572c5fc4fdec383802de5fac3a5e13c24d
parent8a82f100460dc5ca7bd39ae2345c251d6622c494 (diff)
Transactional isolation for `/sync` (#2745)
This should transactional snapshot isolation for `/sync` etc requests. For now we don't use repeatable read due to some odd test failures with invites.
-rw-r--r--syncapi/consumers/clientapi.go5
-rw-r--r--syncapi/consumers/keychange.go5
-rw-r--r--syncapi/consumers/presence.go5
-rw-r--r--syncapi/consumers/receipts.go5
-rw-r--r--syncapi/consumers/roomserver.go19
-rw-r--r--syncapi/consumers/sendtodevice.go5
-rw-r--r--syncapi/consumers/typing.go5
-rw-r--r--syncapi/consumers/userapi.go5
-rw-r--r--syncapi/internal/history_visibility.go2
-rw-r--r--syncapi/notifier/notifier.go19
-rw-r--r--syncapi/routing/context.go30
-rw-r--r--syncapi/routing/messages.go42
-rw-r--r--syncapi/routing/search.go24
-rw-r--r--syncapi/storage/interface.go117
-rw-r--r--syncapi/storage/postgres/invites_table.go22
-rw-r--r--syncapi/storage/shared/storage_consumer.go586
-rw-r--r--syncapi/storage/shared/storage_sync.go574
-rw-r--r--syncapi/storage/shared/syncserver.go1103
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go8
-rw-r--r--syncapi/storage/sqlite3/invites_table.go23
-rw-r--r--syncapi/storage/sqlite3/syncserver.go14
-rw-r--r--syncapi/storage/storage_test.go195
-rw-r--r--syncapi/storage/tables/interface.go2
-rw-r--r--syncapi/streams/stream_accountdata.go17
-rw-r--r--syncapi/streams/stream_devicelist.go7
-rw-r--r--syncapi/streams/stream_invite.go19
-rw-r--r--syncapi/streams/stream_notificationdata.go20
-rw-r--r--syncapi/streams/stream_pdu.go141
-rw-r--r--syncapi/streams/stream_presence.go22
-rw-r--r--syncapi/streams/stream_receipt.go20
-rw-r--r--syncapi/streams/stream_sendtodevice.go20
-rw-r--r--syncapi/streams/stream_typing.go7
-rw-r--r--syncapi/streams/streamprovider.go28
-rw-r--r--syncapi/streams/streams.go77
-rw-r--r--syncapi/streams/template_stream.go10
-rw-r--r--syncapi/sync/requestpool.go53
-rw-r--r--syncapi/types/provider.go20
37 files changed, 1754 insertions, 1522 deletions
diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go
index 796cc61e..735f6718 100644
--- a/syncapi/consumers/clientapi.go
+++ b/syncapi/consumers/clientapi.go
@@ -34,6 +34,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
)
@@ -46,7 +47,7 @@ type OutputClientDataConsumer struct {
topic string
topicReIndex string
db storage.Database
- stream types.StreamProvider
+ stream streams.StreamProvider
notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
fts *fulltext.Search
@@ -61,7 +62,7 @@ func NewOutputClientDataConsumer(
nats *nats.Conn,
store storage.Database,
notifier *notifier.Notifier,
- stream types.StreamProvider,
+ stream streams.StreamProvider,
fts *fulltext.Search,
) *OutputClientDataConsumer {
return &OutputClientDataConsumer{
diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go
index c42e7197..dc7d9e20 100644
--- a/syncapi/consumers/keychange.go
+++ b/syncapi/consumers/keychange.go
@@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
@@ -40,7 +41,7 @@ type OutputKeyChangeEventConsumer struct {
topic string
db storage.Database
notifier *notifier.Notifier
- stream types.StreamProvider
+ stream streams.StreamProvider
serverName gomatrixserverlib.ServerName // our server name
rsAPI roomserverAPI.SyncRoomserverAPI
}
@@ -55,7 +56,7 @@ func NewOutputKeyChangeEventConsumer(
rsAPI roomserverAPI.SyncRoomserverAPI,
store storage.Database,
notifier *notifier.Notifier,
- stream types.StreamProvider,
+ stream streams.StreamProvider,
) *OutputKeyChangeEventConsumer {
s := &OutputKeyChangeEventConsumer{
ctx: process.Context(),
diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go
index 61bdc13d..145059c2 100644
--- a/syncapi/consumers/presence.go
+++ b/syncapi/consumers/presence.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@@ -39,7 +40,7 @@ type PresenceConsumer struct {
requestTopic string
presenceTopic string
db storage.Database
- stream types.StreamProvider
+ stream streams.StreamProvider
notifier *notifier.Notifier
deviceAPI api.SyncUserAPI
cfg *config.SyncAPI
@@ -54,7 +55,7 @@ func NewPresenceConsumer(
nats *nats.Conn,
db storage.Database,
notifier *notifier.Notifier,
- stream types.StreamProvider,
+ stream streams.StreamProvider,
deviceAPI api.SyncUserAPI,
) *PresenceConsumer {
return &PresenceConsumer{
diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go
index 4379dd13..8aaa6573 100644
--- a/syncapi/consumers/receipts.go
+++ b/syncapi/consumers/receipts.go
@@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
)
@@ -38,7 +39,7 @@ type OutputReceiptEventConsumer struct {
durable string
topic string
db storage.Database
- stream types.StreamProvider
+ stream streams.StreamProvider
notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
}
@@ -51,7 +52,7 @@ func NewOutputReceiptEventConsumer(
js nats.JetStreamContext,
store storage.Database,
notifier *notifier.Notifier,
- stream types.StreamProvider,
+ stream streams.StreamProvider,
) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{
ctx: process.Context(),
diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go
index 3756ad75..e5e8fe29 100644
--- a/syncapi/consumers/roomserver.go
+++ b/syncapi/consumers/roomserver.go
@@ -33,6 +33,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
)
@@ -45,8 +46,8 @@ type OutputRoomEventConsumer struct {
durable string
topic string
db storage.Database
- pduStream types.StreamProvider
- inviteStream types.StreamProvider
+ pduStream streams.StreamProvider
+ inviteStream streams.StreamProvider
notifier *notifier.Notifier
fts *fulltext.Search
}
@@ -58,8 +59,8 @@ func NewOutputRoomEventConsumer(
js nats.JetStreamContext,
store storage.Database,
notifier *notifier.Notifier,
- pduStream types.StreamProvider,
- inviteStream types.StreamProvider,
+ pduStream streams.StreamProvider,
+ inviteStream streams.StreamProvider,
rsAPI api.SyncRoomserverAPI,
fts *fulltext.Search,
) *OutputRoomEventConsumer {
@@ -449,8 +450,14 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.Head
}
stateKey := *event.StateKey()
- prevEvent, err := s.db.GetStateEvent(
- context.TODO(), event.RoomID(), event.Type(), stateKey,
+ snapshot, err := s.db.NewDatabaseSnapshot(s.ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer snapshot.Rollback() // nolint:errcheck
+
+ prevEvent, err := snapshot.GetStateEvent(
+ s.ctx, event.RoomID(), event.Type(), stateKey,
)
if err != nil {
return event, err
diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go
index c0b43225..49d84cca 100644
--- a/syncapi/consumers/sendtodevice.go
+++ b/syncapi/consumers/sendtodevice.go
@@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
)
@@ -43,7 +44,7 @@ type OutputSendToDeviceEventConsumer struct {
db storage.Database
keyAPI keyapi.SyncKeyAPI
serverName gomatrixserverlib.ServerName // our server name
- stream types.StreamProvider
+ stream streams.StreamProvider
notifier *notifier.Notifier
}
@@ -56,7 +57,7 @@ func NewOutputSendToDeviceEventConsumer(
store storage.Database,
keyAPI keyapi.SyncKeyAPI,
notifier *notifier.Notifier,
- stream types.StreamProvider,
+ stream streams.StreamProvider,
) *OutputSendToDeviceEventConsumer {
return &OutputSendToDeviceEventConsumer{
ctx: process.Context(),
diff --git a/syncapi/consumers/typing.go b/syncapi/consumers/typing.go
index 88db80f8..67a26239 100644
--- a/syncapi/consumers/typing.go
+++ b/syncapi/consumers/typing.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
+ "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
@@ -36,7 +37,7 @@ type OutputTypingEventConsumer struct {
durable string
topic string
eduCache *caching.EDUCache
- stream types.StreamProvider
+ stream streams.StreamProvider
notifier *notifier.Notifier
}
@@ -48,7 +49,7 @@ func NewOutputTypingEventConsumer(
js nats.JetStreamContext,
eduCache *caching.EDUCache,
notifier *notifier.Notifier,
- stream types.StreamProvider,
+ stream streams.StreamProvider,
) *OutputTypingEventConsumer {
return &OutputTypingEventConsumer{
ctx: process.Context(),
diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go
index c9b96f78..3c73dc1f 100644
--- a/syncapi/consumers/userapi.go
+++ b/syncapi/consumers/userapi.go
@@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
)
@@ -40,7 +41,7 @@ type OutputNotificationDataConsumer struct {
topic string
db storage.Database
notifier *notifier.Notifier
- stream types.StreamProvider
+ stream streams.StreamProvider
}
// NewOutputNotificationDataConsumer creates a new consumer. Call
@@ -51,7 +52,7 @@ func NewOutputNotificationDataConsumer(
js nats.JetStreamContext,
store storage.Database,
notifier *notifier.Notifier,
- stream types.StreamProvider,
+ stream streams.StreamProvider,
) *OutputNotificationDataConsumer {
s := &OutputNotificationDataConsumer{
ctx: process.Context(),
diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go
index e73c004e..bbfe19f4 100644
--- a/syncapi/internal/history_visibility.go
+++ b/syncapi/internal/history_visibility.go
@@ -100,7 +100,7 @@ func (ev eventVisibility) allowed() (allowed bool) {
// Returns the filtered events and an error, if any.
func ApplyHistoryVisibilityFilter(
ctx context.Context,
- syncDB storage.Database,
+ syncDB storage.DatabaseTransaction,
rsAPI api.SyncRoomserverAPI,
events []*gomatrixserverlib.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{},
diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go
index 87f0d86d..a8e5bf9a 100644
--- a/syncapi/notifier/notifier.go
+++ b/syncapi/notifier/notifier.go
@@ -318,13 +318,20 @@ func (n *Notifier) GetListener(req types.SyncRequest) UserDeviceStreamListener {
func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
n.lock.Lock()
defer n.lock.Unlock()
- roomToUsers, err := db.AllJoinedUsersInRooms(ctx)
+
+ snapshot, err := db.NewDatabaseSnapshot(ctx)
+ if err != nil {
+ return err
+ }
+ defer snapshot.Rollback() // nolint:errcheck
+
+ roomToUsers, err := snapshot.AllJoinedUsersInRooms(ctx)
if err != nil {
return err
}
n.setUsersJoinedToRooms(roomToUsers)
- roomToPeekingDevices, err := db.AllPeekingDevicesInRooms(ctx)
+ roomToPeekingDevices, err := snapshot.AllPeekingDevicesInRooms(ctx)
if err != nil {
return err
}
@@ -338,7 +345,13 @@ func (n *Notifier) LoadRooms(ctx context.Context, db storage.Database, roomIDs [
n.lock.Lock()
defer n.lock.Unlock()
- roomToUsers, err := db.AllJoinedUsersInRoom(ctx, roomIDs)
+ snapshot, err := db.NewDatabaseSnapshot(ctx)
+ if err != nil {
+ return err
+ }
+ defer snapshot.Rollback() // nolint:errcheck
+
+ roomToUsers, err := snapshot.AllJoinedUsersInRoom(ctx, roomIDs)
if err != nil {
return err
}
diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go
index 1ebdfe60..1ce34b85 100644
--- a/syncapi/routing/context.go
+++ b/syncapi/routing/context.go
@@ -51,6 +51,12 @@ func Context(
roomID, eventID string,
lazyLoadCache caching.LazyLoadCache,
) util.JSONResponse {
+ snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
+ if err != nil {
+ return jsonerror.InternalServerError()
+ }
+ defer snapshot.Rollback() // nolint:errcheck
+
filter, err := parseRoomEventFilter(req)
if err != nil {
errMsg := ""
@@ -97,7 +103,7 @@ func Context(
ContainsURL: filter.ContainsURL,
}
- id, requestedEvent, err := syncDB.SelectContextEvent(ctx, roomID, eventID)
+ id, requestedEvent, err := snapshot.SelectContextEvent(ctx, roomID, eventID)
if err != nil {
if err == sql.ErrNoRows {
return util.JSONResponse{
@@ -111,7 +117,7 @@ func Context(
// verify the user is allowed to see the context for this room/event
startTime := time.Now()
- filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*gomatrixserverlib.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
@@ -127,20 +133,20 @@ func Context(
}
}
- eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, roomID, filter)
+ eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch before events")
return jsonerror.InternalServerError()
}
- _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, roomID, filter)
+ _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, roomID, filter)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("unable to fetch after events")
return jsonerror.InternalServerError()
}
startTime = time.Now()
- eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, syncDB, rsAPI, eventsBefore, eventsAfter, device.UserID)
+ eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return jsonerror.InternalServerError()
@@ -152,7 +158,7 @@ func Context(
}).Debug("applied history visibility (context eventsBefore/eventsAfter)")
// TODO: Get the actual state at the last event returned by SelectContextAfterEvent
- state, err := syncDB.CurrentState(ctx, roomID, &stateFilter, nil)
+ state, err := snapshot.CurrentState(ctx, roomID, &stateFilter, nil)
if err != nil {
logrus.WithError(err).Error("unable to fetch current room state")
return jsonerror.InternalServerError()
@@ -173,7 +179,7 @@ func Context(
if len(response.State) > filter.Limit {
response.State = response.State[len(response.State)-filter.Limit:]
}
- start, end, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter)
+ start, end, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter)
if err == nil {
response.End = end.String()
response.Start = start.String()
@@ -188,7 +194,7 @@ func Context(
// by combining the events before and after the context event. Returns the filtered events,
// and an error, if any.
func applyHistoryVisibilityOnContextEvents(
- ctx context.Context, syncDB storage.Database, rsAPI roomserver.SyncRoomserverAPI,
+ ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*gomatrixserverlib.HeaderedEvent,
userID string,
) (filteredBefore, filteredAfter []*gomatrixserverlib.HeaderedEvent, err error) {
@@ -205,7 +211,7 @@ func applyHistoryVisibilityOnContextEvents(
}
allEvents := append(eventsBefore, eventsAfter...)
- filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, syncDB, rsAPI, allEvents, nil, userID, "context")
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, allEvents, nil, userID, "context")
if err != nil {
return nil, nil, err
}
@@ -222,15 +228,15 @@ func applyHistoryVisibilityOnContextEvents(
return filteredBefore, filteredAfter, nil
}
-func getStartEnd(ctx context.Context, syncDB storage.Database, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
+func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, startEvents, endEvents []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
if len(startEvents) > 0 {
- start, err = syncDB.EventPositionInTopology(ctx, startEvents[0].EventID())
+ start, err = snapshot.EventPositionInTopology(ctx, startEvents[0].EventID())
if err != nil {
return
}
}
if len(endEvents) > 0 {
- end, err = syncDB.EventPositionInTopology(ctx, endEvents[0].EventID())
+ end, err = snapshot.EventPositionInTopology(ctx, endEvents[0].EventID())
}
return
}
diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go
index 03614302..8f3ed3f5 100644
--- a/syncapi/routing/messages.go
+++ b/syncapi/routing/messages.go
@@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/caching"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/internal"
@@ -39,6 +40,7 @@ import (
type messagesReq struct {
ctx context.Context
db storage.Database
+ snapshot storage.DatabaseTransaction
rsAPI api.SyncRoomserverAPI
cfg *config.SyncAPI
roomID string
@@ -70,6 +72,16 @@ func OnIncomingMessagesRequest(
) util.JSONResponse {
var err error
+ // NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we
+ // expect to be able to write to the database in response to a /messages
+ // request that requires backfilling from the roomserver or federation.
+ snapshot, err := db.NewDatabaseTransaction(req.Context())
+ if err != nil {
+ return jsonerror.InternalServerError()
+ }
+ var succeeded bool
+ defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
+
// check if the user has already forgotten about this room
isForgotten, roomExists, err := checkIsRoomForgotten(req.Context(), roomID, device.UserID, rsAPI)
if err != nil {
@@ -132,7 +144,7 @@ func OnIncomingMessagesRequest(
}
} else {
fromStream = &streamToken
- from, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering)
+ from, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering)
if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError()
@@ -154,7 +166,7 @@ func OnIncomingMessagesRequest(
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
}
} else {
- to, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering)
+ to, err = snapshot.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering)
if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError()
@@ -165,7 +177,7 @@ func OnIncomingMessagesRequest(
// If "to" isn't provided, it defaults to either the earliest stream
// position (if we're going backward) or to the latest one (if we're
// going forward).
- to, err = setToDefault(req.Context(), db, backwardOrdering, roomID)
+ to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed")
return jsonerror.InternalServerError()
@@ -186,6 +198,7 @@ func OnIncomingMessagesRequest(
mReq := messagesReq{
ctx: req.Context(),
db: db,
+ snapshot: snapshot,
rsAPI: rsAPI,
cfg: cfg,
roomID: roomID,
@@ -217,7 +230,7 @@ func OnIncomingMessagesRequest(
Start: start.String(),
End: end.String(),
}
- res.applyLazyLoadMembers(req.Context(), db, roomID, device, filter.LazyLoadMembers, lazyLoadCache)
+ res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache)
// If we didn't return any events, set the end to an empty string, so it will be omitted
// in the response JSON.
@@ -229,6 +242,7 @@ func OnIncomingMessagesRequest(
}
// Respond with the events.
+ succeeded = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: res,
@@ -239,7 +253,7 @@ func OnIncomingMessagesRequest(
// LazyLoadMembers enabled.
func (m *messagesResp) applyLazyLoadMembers(
ctx context.Context,
- db storage.Database,
+ db storage.DatabaseTransaction,
roomID string,
device *userapi.Device,
lazyLoad bool,
@@ -292,7 +306,7 @@ func (r *messagesReq) retrieveEvents() (
end types.TopologyToken, err error,
) {
// Retrieve the events from the local database.
- streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
+ streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err)
return
@@ -348,7 +362,7 @@ func (r *messagesReq) retrieveEvents() (
// Apply room history visibility filter
startTime := time.Now()
- filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.db, r.rsAPI, events, nil, r.device.UserID, "messages")
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages")
logrus.WithFields(logrus.Fields{
"duration": time.Since(startTime),
"room_id": r.roomID,
@@ -366,7 +380,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
// else to go. This seems to fix Element iOS from looping on /messages endlessly.
end = types.TopologyToken{}
} else {
- end, err = r.db.EventPositionInTopology(
+ end, err = r.snapshot.EventPositionInTopology(
r.ctx, events[0].EventID(),
)
// A stream/topological position is a cursor located between two events.
@@ -378,7 +392,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
}
} else {
start = *r.from
- end, err = r.db.EventPositionInTopology(
+ end, err = r.snapshot.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(),
)
}
@@ -399,7 +413,7 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
func (r *messagesReq) handleEmptyEventsSlice() (
events []*gomatrixserverlib.HeaderedEvent, err error,
) {
- backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
+ backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID)
// Check if we have backward extremities for this room.
if len(backwardExtremities) > 0 {
@@ -443,7 +457,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
}
// Check if the slice contains a backward extremity.
- backwardExtremities, err := r.db.BackwardExtremitiesForRoom(r.ctx, r.roomID)
+ backwardExtremities, err := r.snapshot.BackwardExtremitiesForRoom(r.ctx, r.roomID)
if err != nil {
return
}
@@ -463,7 +477,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
}
// Append the events ve previously retrieved locally.
- events = append(events, r.db.StreamEventsToEvents(nil, streamEvents)...)
+ events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...)
sort.Sort(eventsByDepth(events))
return
@@ -553,7 +567,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
// Returns an error if there was an issue with retrieving the latest position
// from the database
func setToDefault(
- ctx context.Context, db storage.Database, backwardOrdering bool,
+ ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool,
roomID string,
) (to types.TopologyToken, err error) {
if backwardOrdering {
@@ -561,7 +575,7 @@ func setToDefault(
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
to = types.TopologyToken{}
} else {
- to, err = db.MaxTopologicalPosition(ctx, roomID)
+ to, err = snapshot.MaxTopologicalPosition(ctx, roomID)
}
return
diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go
index 341efeb1..bac534a2 100644
--- a/syncapi/routing/search.go
+++ b/syncapi/routing/search.go
@@ -61,8 +61,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
searchReq.SearchCategories.RoomEvents.Filter.Limit = 5
}
+ snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
+ if err != nil {
+ return jsonerror.InternalServerError()
+ }
+ defer snapshot.Rollback() // nolint:errcheck
+
// only search rooms the user is actually joined to
- joinedRooms, err := syncDB.RoomIDsWithMembership(ctx, device.UserID, "join")
+ joinedRooms, err := snapshot.RoomIDsWithMembership(ctx, device.UserID, "join")
if err != nil {
return jsonerror.InternalServerError()
}
@@ -161,12 +167,12 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
stateForRooms := make(map[string][]gomatrixserverlib.ClientEvent)
for _, event := range evs {
- eventsBefore, eventsAfter, err := contextEvents(ctx, syncDB, event, roomFilter, searchReq)
+ eventsBefore, eventsAfter, err := contextEvents(ctx, snapshot, event, roomFilter, searchReq)
if err != nil {
logrus.WithError(err).Error("failed to get context events")
return jsonerror.InternalServerError()
}
- startToken, endToken, err := getStartEnd(ctx, syncDB, eventsBefore, eventsAfter)
+ startToken, endToken, err := getStartEnd(ctx, snapshot, eventsBefore, eventsAfter)
if err != nil {
logrus.WithError(err).Error("failed to get start/end")
return jsonerror.InternalServerError()
@@ -176,7 +182,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
for _, ev := range append(eventsBefore, eventsAfter...) {
profile, ok := knownUsersProfiles[event.Sender()]
if !ok {
- stateEvent, err := syncDB.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender())
+ stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), gomatrixserverlib.MRoomMember, ev.Sender())
if err != nil {
logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile")
continue
@@ -209,7 +215,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
groups[event.RoomID()] = roomGroup
if _, ok := stateForRooms[event.RoomID()]; searchReq.SearchCategories.RoomEvents.IncludeState && !ok {
stateFilter := gomatrixserverlib.DefaultStateFilter()
- state, err := syncDB.CurrentState(ctx, event.RoomID(), &stateFilter, nil)
+ state, err := snapshot.CurrentState(ctx, event.RoomID(), &stateFilter, nil)
if err != nil {
logrus.WithError(err).Error("unable to get current state")
return jsonerror.InternalServerError()
@@ -252,24 +258,24 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
// contextEvents returns the events around a given eventID
func contextEvents(
ctx context.Context,
- syncDB storage.Database,
+ snapshot storage.DatabaseTransaction,
event *gomatrixserverlib.HeaderedEvent,
roomFilter *gomatrixserverlib.RoomEventFilter,
searchReq SearchRequest,
) ([]*gomatrixserverlib.HeaderedEvent, []*gomatrixserverlib.HeaderedEvent, error) {
- id, _, err := syncDB.SelectContextEvent(ctx, event.RoomID(), event.EventID())
+ id, _, err := snapshot.SelectContextEvent(ctx, event.RoomID(), event.EventID())
if err != nil {
logrus.WithError(err).Error("failed to query context event")
return nil, nil, err
}
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.BeforeLimit
- eventsBefore, err := syncDB.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter)
+ eventsBefore, err := snapshot.SelectContextBeforeEvent(ctx, id, event.RoomID(), roomFilter)
if err != nil {
logrus.WithError(err).Error("failed to query before context event")
return nil, nil, err
}
roomFilter.Limit = searchReq.SearchCategories.RoomEvents.EventContext.AfterLimit
- _, eventsAfter, err := syncDB.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter)
+ _, eventsAfter, err := snapshot.SelectContextAfterEvent(ctx, id, event.RoomID(), roomFilter)
if err != nil {
logrus.WithError(err).Error("failed to query after context event")
return nil, nil, err
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index dd03365e..3732e43f 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -17,19 +17,17 @@ package storage
import (
"context"
- "github.com/matrix-org/dendrite/internal/eventutil"
-
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
-type Database interface {
- Presence
+type DatabaseTransaction interface {
SharedUsers
- Notifications
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
@@ -37,6 +35,7 @@ type Database interface {
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
+ MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@@ -44,21 +43,16 @@ type Database interface {
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error)
GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error)
-
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
-
GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
-
- InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error)
+ InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error)
PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error)
-
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
// AllJoinedUsersInRoom returns a map of room ID to a list of all joined user IDs for a given room.
AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error)
-
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)
// Events lookups a list of event by their event ID.
@@ -67,16 +61,6 @@ type Database interface {
// Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
- // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
- // when generating the sync stream position for this event. Returns the sync stream position for the inserted event.
- // Returns an error if there was a problem inserting this event.
- WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent,
- addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool,
- historyVisibility gomatrixserverlib.HistoryVisibility,
- ) (types.StreamPosition, error)
- // PurgeRoomState completely purges room state from the sync API. This is done when
- // receiving an output event that completely resets the state.
- PurgeRoomState(ctx context.Context, roomID string) error
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error
@@ -91,6 +75,61 @@ type Database interface {
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, types.StreamPosition, error)
+ // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
+ GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
+ // EventPositionInTopology returns the depth and stream position of the given event.
+ EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
+ // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
+ BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
+ // MaxTopologicalPosition returns the highest topological position for a given room.
+ MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
+ // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
+ // matches the streamevent.transactionID device then the transaction ID gets
+ // added to the unsigned section of the output event.
+ StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
+ // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
+ // relevant events within the given ranges for the supplied user ID and device ID.
+ SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
+ // GetRoomReceipts gets all receipts for a given roomID
+ GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
+ SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
+ SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
+ SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
+ StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)
+ IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
+ // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
+ // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
+ // string as the membership.
+ SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
+ // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
+ GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
+ GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
+ PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
+}
+
+type Database interface {
+ Presence
+ Notifications
+
+ NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error)
+ NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error)
+
+ // Events lookups a list of event by their event ID.
+ // Returns a list of events matching the requested IDs found in the database.
+ // If an event is not found in the database then it will be omitted from the list.
+ // Returns an error if there was a problem talking with the database.
+ // Does not include any transaction IDs in the returned events.
+ Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
+ // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
+ // when generating the sync stream position for this event. Returns the sync stream position for the inserted event.
+ // Returns an error if there was a problem inserting this event.
+ WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []*gomatrixserverlib.HeaderedEvent,
+ addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool,
+ historyVisibility gomatrixserverlib.HistoryVisibility,
+ ) (types.StreamPosition, error)
+ // PurgeRoomState completely purges room state from the sync API. This is done when
+ // receiving an output event that completely resets the state.
+ PurgeRoomState(ctx context.Context, roomID string) error
// UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to (empty)
// room ID means the data isn't specific to any room)
@@ -114,21 +153,6 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
- // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
- GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
- // EventPositionInTopology returns the depth and stream position of the given event.
- EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
- // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
- BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
- // MaxTopologicalPosition returns the highest topological position for a given room.
- MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
- // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
- // matches the streamevent.transactionID device then the transaction ID gets
- // added to the unsigned section of the output event.
- StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
- // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
- // relevant events within the given ranges for the supplied user ID and device ID.
- SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
@@ -146,29 +170,13 @@ type Database interface {
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error
// StoreReceipt stores new receipt events
StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
- // GetRoomReceipts gets all receipts for a given roomID
- GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error)
-
- SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
- SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
- SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
-
- StreamToTopologicalPosition(ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool) (types.TopologyToken, error)
-
- IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error)
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
- // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found
- // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty
- // string as the membership.
- SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error)
ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error)
}
type Presence interface {
- UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
- PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
- MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
+ UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error)
}
type SharedUsers interface {
@@ -179,7 +187,4 @@ type SharedUsers interface {
type Notifications interface {
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
-
- // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms
- GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
}
diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go
index f87ccf96..aada70d5 100644
--- a/syncapi/storage/postgres/invites_table.go
+++ b/syncapi/storage/postgres/invites_table.go
@@ -55,7 +55,7 @@ const deleteInviteEventSQL = "" +
"UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id"
const selectInviteEventsInRangeSQL = "" +
- "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
+ "SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" +
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC"
@@ -121,23 +121,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
// active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) SelectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
-) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
+) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
+ var lastPos types.StreamPosition
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
if err != nil {
- return nil, nil, err
+ return nil, nil, lastPos, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
result := map[string]*gomatrixserverlib.HeaderedEvent{}
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
for rows.Next() {
var (
+ id types.StreamPosition
roomID string
eventJSON []byte
deleted bool
)
- if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
- return nil, nil, err
+ if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil {
+ return nil, nil, lastPos, err
+ }
+ if id > lastPos {
+ lastPos = id
}
// if we have seen this room before, it has a higher stream position and hence takes priority
@@ -150,7 +155,7 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
var event *gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventJSON, &event); err != nil {
- return nil, nil, err
+ return nil, nil, lastPos, err
}
if deleted {
@@ -159,7 +164,10 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
result[roomID] = event
}
}
- return result, retired, rows.Err()
+ if lastPos == 0 {
+ lastPos = r.To
+ }
+ return result, retired, lastPos, rows.Err()
}
func (s *inviteEventsStatements) SelectMaxInviteID(
diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go
new file mode 100644
index 00000000..fb3b295e
--- /dev/null
+++ b/syncapi/storage/shared/storage_consumer.go
@@ -0,0 +1,586 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shared
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+
+ "github.com/tidwall/gjson"
+
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/syncapi/storage/tables"
+ "github.com/matrix-org/dendrite/syncapi/types"
+)
+
+// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
+// For now this contains the shared functions
+type Database struct {
+ DB *sql.DB
+ Writer sqlutil.Writer
+ Invites tables.Invites
+ Peeks tables.Peeks
+ AccountData tables.AccountData
+ OutputEvents tables.Events
+ Topology tables.Topology
+ CurrentRoomState tables.CurrentRoomState
+ BackwardExtremities tables.BackwardsExtremities
+ SendToDevice tables.SendToDevice
+ Filter tables.Filter
+ Receipts tables.Receipts
+ Memberships tables.Memberships
+ NotificationData tables.NotificationData
+ Ignores tables.Ignores
+ Presence tables.Presence
+}
+
+func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) {
+ return d.NewDatabaseTransaction(ctx)
+
+ /*
+ TODO: Repeatable read is probably the right thing to do here,
+ but it seems to cause some problems with the invite tests, so
+ need to investigate that further.
+
+ txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{
+ // Set the isolation level so that we see a snapshot of the database.
+ // In PostgreSQL repeatable read transactions will see a snapshot taken
+ // at the first query, and since the transaction is read-only it can't
+ // run into any serialisation errors.
+ // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
+ Isolation: sql.LevelRepeatableRead,
+ ReadOnly: true,
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &DatabaseTransaction{
+ Database: d,
+ txn: txn,
+ }, nil
+ */
+}
+
+func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) {
+ txn, err := d.DB.BeginTx(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ return &DatabaseTransaction{
+ Database: d,
+ txn: txn,
+ }, nil
+}
+
+func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
+ streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
+ if err != nil {
+ return nil, err
+ }
+
+ // We don't include a device here as we only include transaction IDs in
+ // incremental syncs.
+ return d.StreamEventsToEvents(nil, streamEvents), nil
+}
+
+// AddInviteEvent stores a new invite event for a user.
+// If the invite was successfully stored this returns the stream ID it was stored at.
+// Returns an error if there was a problem communicating with the database.
+func (d *Database) AddInviteEvent(
+ ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent,
+) (sp types.StreamPosition, err error) {
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
+ return err
+ })
+ return
+}
+
+// RetireInviteEvent removes an old invite event from the database.
+// Returns an error if there was a problem communicating with the database.
+func (d *Database) RetireInviteEvent(
+ ctx context.Context, inviteEventID string,
+) (sp types.StreamPosition, err error) {
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID)
+ return err
+ })
+ return
+}
+
+// AddPeek tracks the fact that a user has started peeking.
+// If the peek was successfully stored this returns the stream ID it was stored at.
+// Returns an error if there was a problem communicating with the database.
+func (d *Database) AddPeek(
+ ctx context.Context, roomID, userID, deviceID string,
+) (sp types.StreamPosition, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID)
+ return err
+ })
+ return
+}
+
+// DeletePeek tracks the fact that a user has stopped peeking from the specified
+// device. If the peeks was successfully deleted this returns the stream ID it was
+// stored at. Returns an error if there was a problem communicating with the database.
+func (d *Database) DeletePeek(
+ ctx context.Context, roomID, userID, deviceID string,
+) (sp types.StreamPosition, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID)
+ return err
+ })
+ if err == sql.ErrNoRows {
+ sp = 0
+ err = nil
+ }
+ return
+}
+
+// DeletePeeks tracks the fact that a user has stopped peeking from all devices
+// If the peeks was successfully deleted this returns the stream ID it was stored at.
+// Returns an error if there was a problem communicating with the database.
+func (d *Database) DeletePeeks(
+ ctx context.Context, roomID, userID string,
+) (sp types.StreamPosition, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID)
+ return err
+ })
+ if err == sql.ErrNoRows {
+ sp = 0
+ err = nil
+ }
+ return
+}
+
+// UpsertAccountData keeps track of new or updated account data, by saving the type
+// of the new/updated data, and the user ID and room ID the data is related to (empty)
+// room ID means the data isn't specific to any room)
+// If no data with the given type, user ID and room ID exists in the database,
+// creates a new row, else update the existing one
+// Returns an error if there was an issue with the upsert
+func (d *Database) UpsertAccountData(
+ ctx context.Context, userID, roomID, dataType string,
+) (sp types.StreamPosition, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
+ return err
+ })
+ return
+}
+
+func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent {
+ out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
+ for i := 0; i < len(in); i++ {
+ out[i] = in[i].HeaderedEvent
+ if device != nil && in[i].TransactionID != nil {
+ if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
+ err := out[i].SetUnsignedField(
+ "transaction_id", in[i].TransactionID.TransactionID,
+ )
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ "event_id": out[i].EventID(),
+ }).WithError(err).Warnf("Failed to add transaction ID to event")
+ }
+ }
+ }
+ }
+ return out
+}
+
+// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
+// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
+// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
+// This function should always be called within a sqlutil.Writer for safety in SQLite.
+func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
+ if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
+ return err
+ }
+
+ // Check if we have all of the event's previous events. If an event is
+ // missing, add it to the room's backward extremities.
+ prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
+ if err != nil {
+ return err
+ }
+ var found bool
+ for _, eID := range ev.PrevEventIDs() {
+ found = false
+ for _, prevEv := range prevEvents {
+ if eID == prevEv.EventID() {
+ found = true
+ }
+ }
+
+ // If the event is missing, consider it a backward extremity.
+ if !found {
+ if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func (d *Database) PurgeRoomState(
+ ctx context.Context, roomID string,
+) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ // If the event is a create event then we'll delete all of the existing
+ // data for the room. The only reason that a create event would be replayed
+ // to us in this way is if we're about to receive the entire room state.
+ if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
+ return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
+ }
+ return nil
+ })
+}
+
+func (d *Database) WriteEvent(
+ ctx context.Context,
+ ev *gomatrixserverlib.HeaderedEvent,
+ addStateEvents []*gomatrixserverlib.HeaderedEvent,
+ addStateEventIDs, removeStateEventIDs []string,
+ transactionID *api.TransactionID, excludeFromSync bool,
+ historyVisibility gomatrixserverlib.HistoryVisibility,
+) (pduPosition types.StreamPosition, returnErr error) {
+ returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ var err error
+ ev.Visibility = historyVisibility
+ pos, err := d.OutputEvents.InsertEvent(
+ ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility,
+ )
+ if err != nil {
+ return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
+ }
+ pduPosition = pos
+ var topoPosition types.StreamPosition
+ if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
+ return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
+ }
+
+ if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
+ return fmt.Errorf("d.handleBackwardExtremities: %w", err)
+ }
+
+ if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
+ // Nothing to do, the event may have just been a message event.
+ return nil
+ }
+ for i := range addStateEvents {
+ addStateEvents[i].Visibility = historyVisibility
+ }
+ return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition)
+ })
+
+ return pduPosition, returnErr
+}
+
+// This function should always be called within a sqlutil.Writer for safety in SQLite.
+func (d *Database) updateRoomState(
+ ctx context.Context, txn *sql.Tx,
+ removedEventIDs []string,
+ addedEvents []*gomatrixserverlib.HeaderedEvent,
+ pduPosition types.StreamPosition,
+ topoPosition types.StreamPosition,
+) error {
+ // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
+ for _, eventID := range removedEventIDs {
+ if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
+ return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
+ }
+ }
+
+ for _, event := range addedEvents {
+ if event.StateKey() == nil {
+ // ignore non state events
+ continue
+ }
+ var membership *string
+ if event.Type() == "m.room.member" {
+ value, err := event.Membership()
+ if err != nil {
+ return fmt.Errorf("event.Membership: %w", err)
+ }
+ membership = &value
+ if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil {
+ return fmt.Errorf("d.Memberships.UpsertMembership: %w", err)
+ }
+ }
+
+ if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
+ return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func (d *Database) GetFilter(
+ ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
+) error {
+ return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID)
+}
+
+func (d *Database) PutFilter(
+ ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
+) (string, error) {
+ var filterID string
+ var err error
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ filterID, err = d.Filter.InsertFilter(ctx, txn, filter, localpart)
+ return err
+ })
+ return filterID, err
+}
+
+func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error {
+ redactedEvents, err := d.Events(ctx, []string{redactedEventID})
+ if err != nil {
+ return err
+ }
+ if len(redactedEvents) == 0 {
+ logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction")
+ return nil
+ }
+ eventToRedact := redactedEvents[0].Unwrap()
+ redactionEvent := redactedBecause.Unwrap()
+ if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil {
+ return err
+ }
+
+ newEvent := eventToRedact.Headered(redactedBecause.RoomVersion)
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent)
+ })
+ return err
+}
+
+// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
+// Returns a map of room ID to list of events.
+func (d *Database) fetchStateEvents(
+ ctx context.Context, txn *sql.Tx,
+ roomIDToEventIDSet map[string]map[string]bool,
+ eventIDToEvent map[string]types.StreamEvent,
+) (map[string][]types.StreamEvent, error) {
+ stateBetween := make(map[string][]types.StreamEvent)
+ missingEvents := make(map[string][]string)
+ for roomID, ids := range roomIDToEventIDSet {
+ events := stateBetween[roomID]
+ for id, need := range ids {
+ if !need {
+ continue // deleted state
+ }
+ e, ok := eventIDToEvent[id]
+ if ok {
+ events = append(events, e)
+ } else {
+ m := missingEvents[roomID]
+ m = append(m, id)
+ missingEvents[roomID] = m
+ }
+ }
+ stateBetween[roomID] = events
+ }
+
+ if len(missingEvents) > 0 {
+ // This happens when add_state_ids has an event ID which is not in the provided range.
+ // We need to explicitly fetch them.
+ allMissingEventIDs := []string{}
+ for _, missingEvIDs := range missingEvents {
+ allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
+ }
+ evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs)
+ if err != nil {
+ return nil, err
+ }
+ // we know we got them all otherwise an error would've been returned, so just loop the events
+ for _, ev := range evs {
+ roomID := ev.RoomID()
+ stateBetween[roomID] = append(stateBetween[roomID], ev)
+ }
+ }
+ return stateBetween, nil
+}
+
+func (d *Database) fetchMissingStateEvents(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) ([]types.StreamEvent, error) {
+ // Fetch from the events table first so we pick up the stream ID for the
+ // event.
+ events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
+ if err != nil {
+ return nil, err
+ }
+
+ have := map[string]bool{}
+ for _, event := range events {
+ have[event.EventID()] = true
+ }
+ var missing []string
+ for _, eventID := range eventIDs {
+ if !have[eventID] {
+ missing = append(missing, eventID)
+ }
+ }
+ if len(missing) == 0 {
+ return events, nil
+ }
+
+ // If they are missing from the events table then they should be state
+ // events that we received from outside the main event stream.
+ // These should be in the room state table.
+ stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing)
+
+ if err != nil {
+ return nil, err
+ }
+ if len(stateEvents) != len(missing) {
+ logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing))
+
+ // TODO: Why is this happening? It's probably the roomserver. Uncomment
+ // this error again when we work out what it is and fix it, otherwise we
+ // just end up returning lots of 500s to the client and that breaks
+ // pretty much everything, rather than just sending what we have.
+ //return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
+ }
+ events = append(events, stateEvents...)
+ return events, nil
+}
+
+func (d *Database) StoreNewSendForDeviceMessage(
+ ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
+) (newPos types.StreamPosition, err error) {
+ j, err := json.Marshal(event)
+ if err != nil {
+ return 0, err
+ }
+ // Delegate the database write task to the SendToDeviceWriter. It'll guarantee
+ // that we don't lock the table for writes in more than one place.
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
+ ctx, txn, userID, deviceID, string(j),
+ )
+ return err
+ })
+ if err != nil {
+ return 0, err
+ }
+ return newPos, nil
+}
+
+func (d *Database) CleanSendToDeviceUpdates(
+ ctx context.Context,
+ userID, deviceID string, before types.StreamPosition,
+) (err error) {
+ if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before)
+ }); err != nil {
+ logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID)
+ return err
+ }
+ return nil
+}
+
+// getMembershipFromEvent returns the value of content.membership iff the event is a state event
+// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
+func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) {
+ if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
+ return "", ""
+ }
+ membership, err := ev.Membership()
+ if err != nil {
+ return "", ""
+ }
+ prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str
+ return membership, prevMembership
+}
+
+// StoreReceipt stores user receipts
+func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp)
+ return err
+ })
+ return
+}
+
+func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, txn, userID, roomID, notificationCount, highlightCount)
+ return err
+ })
+ return
+}
+
+func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
+ return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
+}
+
+func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
+ return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
+}
+func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
+ return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
+}
+
+func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
+ return d.Ignores.SelectIgnores(ctx, nil, userID)
+}
+
+func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.Ignores.UpsertIgnores(ctx, txn, userID, ignores)
+ })
+}
+
+func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
+ var pos types.StreamPosition
+ var err error
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ pos, err = d.Presence.UpsertPresence(ctx, txn, userID, statusMsg, presence, lastActiveTS, fromSync)
+ return nil
+ })
+ return pos, err
+}
+
+func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
+ return d.Presence.GetPresenceForUser(ctx, nil, userID)
+}
+
+func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
+ return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
+}
+
+func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
+ return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{
+ gomatrixserverlib.MRoomName,
+ gomatrixserverlib.MRoomTopic,
+ "m.room.message",
+ })
+}
diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go
new file mode 100644
index 00000000..a19135a6
--- /dev/null
+++ b/syncapi/storage/shared/storage_sync.go
@@ -0,0 +1,574 @@
+package shared
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type DatabaseTransaction struct {
+ *Database
+ txn *sql.Tx
+}
+
+func (d *DatabaseTransaction) Commit() error {
+ if d.txn == nil {
+ return nil
+ }
+ return d.txn.Commit()
+}
+
+func (d *DatabaseTransaction) Rollback() error {
+ if d.txn == nil {
+ return nil
+ }
+ return d.txn.Rollback()
+}
+
+func (d *DatabaseTransaction) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) {
+ id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn)
+ if err != nil {
+ return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err)
+ }
+ return types.StreamPosition(id), nil
+}
+
+func (d *DatabaseTransaction) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) {
+ id, err := d.Receipts.SelectMaxReceiptID(ctx, d.txn)
+ if err != nil {
+ return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err)
+ }
+ return types.StreamPosition(id), nil
+}
+
+func (d *DatabaseTransaction) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) {
+ id, err := d.Invites.SelectMaxInviteID(ctx, d.txn)
+ if err != nil {
+ return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err)
+ }
+ return types.StreamPosition(id), nil
+}
+
+func (d *DatabaseTransaction) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
+ id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, d.txn)
+ if err != nil {
+ return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
+ }
+ return types.StreamPosition(id), nil
+}
+
+func (d *DatabaseTransaction) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
+ id, err := d.AccountData.SelectMaxAccountDataID(ctx, d.txn)
+ if err != nil {
+ return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err)
+ }
+ return types.StreamPosition(id), nil
+}
+
+func (d *DatabaseTransaction) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
+ id, err := d.NotificationData.SelectMaxID(ctx, d.txn)
+ if err != nil {
+ return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
+ }
+ return types.StreamPosition(id), nil
+}
+
+func (d *DatabaseTransaction) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
+ return d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilterPart, excludeEventIDs)
+}
+
+func (d *DatabaseTransaction) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) {
+ return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, d.txn, userID, membership)
+}
+
+func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) {
+ return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos)
+}
+
+func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
+ return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships)
+}
+
+func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
+ return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
+}
+
+func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
+ return d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
+}
+
+func (d *DatabaseTransaction) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
+ return d.Invites.SelectInviteEventsInRange(ctx, d.txn, targetUserID, r)
+}
+
+func (d *DatabaseTransaction) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) {
+ return d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, deviceID, r)
+}
+
+func (d *DatabaseTransaction) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
+ return d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
+}
+
+// Events lookups a list of event by their event ID.
+// Returns a list of events matching the requested IDs found in the database.
+// If an event is not found in the database then it will be omitted from the list.
+// Returns an error if there was a problem talking with the database.
+// Does not include any transaction IDs in the returned events.
+func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
+ streamEvents, err := d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, false)
+ if err != nil {
+ return nil, err
+ }
+
+ // We don't include a device here as we only include transaction IDs in
+ // incremental syncs.
+ return d.StreamEventsToEvents(nil, streamEvents), nil
+}
+
+func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
+ return d.CurrentRoomState.SelectJoinedUsers(ctx, d.txn)
+}
+
+func (d *DatabaseTransaction) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
+ return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, d.txn, roomIDs)
+}
+
+func (d *DatabaseTransaction) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
+ return d.Peeks.SelectPeekingDevices(ctx, d.txn)
+}
+
+func (d *DatabaseTransaction) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
+ return d.CurrentRoomState.SelectSharedUsers(ctx, d.txn, userID, otherUserIDs)
+}
+
+func (d *DatabaseTransaction) GetStateEvent(
+ ctx context.Context, roomID, evType, stateKey string,
+) (*gomatrixserverlib.HeaderedEvent, error) {
+ return d.CurrentRoomState.SelectStateEvent(ctx, d.txn, roomID, evType, stateKey)
+}
+
+func (d *DatabaseTransaction) GetStateEventsForRoom(
+ ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
+) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) {
+ stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil)
+ return
+}
+
+// GetAccountDataInRange returns all account data for a given user inserted or
+// updated between two given positions
+// Returns a map following the format data[roomID] = []dataTypes
+// If no data is retrieved, returns an empty map
+// If there was an issue with the retrieval, returns an error
+func (d *DatabaseTransaction) GetAccountDataInRange(
+ ctx context.Context, userID string, r types.Range,
+ accountDataFilterPart *gomatrixserverlib.EventFilter,
+) (map[string][]string, types.StreamPosition, error) {
+ return d.AccountData.SelectAccountDataInRange(ctx, d.txn, userID, r, accountDataFilterPart)
+}
+
+func (d *DatabaseTransaction) GetEventsInTopologicalRange(
+ ctx context.Context,
+ from, to *types.TopologyToken,
+ roomID string,
+ filter *gomatrixserverlib.RoomEventFilter,
+ backwardOrdering bool,
+) (events []types.StreamEvent, err error) {
+ var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
+ if backwardOrdering {
+ // Backward ordering means the 'from' token has a higher depth than the 'to' token
+ minDepth = to.Depth
+ maxDepth = from.Depth
+ // for cases where we have say 5 events with the same depth, the TopologyToken needs to
+ // know which of the 5 the client has seen. This is done by using the PDU position.
+ // Events with the same maxDepth but less than this PDU position will be returned.
+ maxStreamPosForMaxDepth = from.PDUPosition
+ } else {
+ // Forward ordering means the 'from' token has a lower depth than the 'to' token.
+ minDepth = from.Depth
+ maxDepth = to.Depth
+ }
+
+ // Select the event IDs from the defined range.
+ var eIDs []string
+ eIDs, err = d.Topology.SelectEventIDsInRange(
+ ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
+ )
+ if err != nil {
+ return
+ }
+
+ // Retrieve the events' contents using their IDs.
+ events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true)
+ return
+}
+
+func (d *DatabaseTransaction) BackwardExtremitiesForRoom(
+ ctx context.Context, roomID string,
+) (backwardExtremities map[string][]string, err error) {
+ return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID)
+}
+
+func (d *DatabaseTransaction) MaxTopologicalPosition(
+ ctx context.Context, roomID string,
+) (types.TopologyToken, error) {
+ depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
+ if err != nil {
+ return types.TopologyToken{}, err
+ }
+ return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
+}
+
+func (d *DatabaseTransaction) EventPositionInTopology(
+ ctx context.Context, eventID string,
+) (types.TopologyToken, error) {
+ depth, stream, err := d.Topology.SelectPositionInTopology(ctx, d.txn, eventID)
+ if err != nil {
+ return types.TopologyToken{}, err
+ }
+ return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
+}
+
+func (d *DatabaseTransaction) StreamToTopologicalPosition(
+ ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
+) (types.TopologyToken, error) {
+ topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, d.txn, roomID, streamPos, backwardOrdering)
+ switch {
+ case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward
+ return types.TopologyToken{PDUPosition: streamPos}, nil
+ case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward
+ topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
+ if err != nil {
+ return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err)
+ }
+ return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
+ case err != nil: // some other error happened
+ return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err)
+ default:
+ return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
+ }
+}
+
+// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
+// oldest event in the room's topology.
+func (d *DatabaseTransaction) GetBackwardTopologyPos(
+ ctx context.Context,
+ events []types.StreamEvent,
+) (types.TopologyToken, error) {
+ zeroToken := types.TopologyToken{}
+ if len(events) == 0 {
+ return zeroToken, nil
+ }
+ pos, spos, err := d.Topology.SelectPositionInTopology(ctx, d.txn, events[0].EventID())
+ if err != nil {
+ return zeroToken, err
+ }
+ tok := types.TopologyToken{Depth: pos, PDUPosition: spos}
+ tok.Decrement()
+ return tok, nil
+}
+
+// GetStateDeltas returns the state deltas between fromPos and toPos,
+// exclusive of oldPos, inclusive of newPos, for the rooms in which
+// the user has new membership events.
+// A list of joined room IDs is also returned in case the caller needs it.
+func (d *DatabaseTransaction) GetStateDeltas(
+ ctx context.Context, device *userapi.Device,
+ r types.Range, userID string,
+ stateFilter *gomatrixserverlib.StateFilter,
+) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
+ // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
+ // - Get membership list changes for this user in this sync response
+ // - For each room which has membership list changes:
+ // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO).
+ // If it is, then we need to send the full room state down (and 'limited' is always true).
+ // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
+ // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
+ // - Get all CURRENTLY joined rooms, and add them to 'joined' block.
+
+ // Look up all memberships for the user. We only care about rooms that a
+ // user has ever interacted with — joined to, kicked/banned from, left.
+ memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil, nil
+ }
+ return nil, nil, err
+ }
+
+ allRoomIDs := make([]string, 0, len(memberships))
+ joinedRoomIDs := make([]string, 0, len(memberships))
+ for roomID, membership := range memberships {
+ allRoomIDs = append(allRoomIDs, roomID)
+ if membership == gomatrixserverlib.Join {
+ joinedRoomIDs = append(joinedRoomIDs, roomID)
+ }
+ }
+
+ // get all the state events ever (i.e. for all available rooms) between these two positions
+ stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil, nil
+ }
+ return nil, nil, err
+ }
+ state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil, nil
+ }
+ return nil, nil, err
+ }
+
+ // find out which rooms this user is peeking, if any.
+ // We do this before joins so any peeks get overwritten
+ peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
+ if err != nil && err != sql.ErrNoRows {
+ return nil, nil, err
+ }
+
+ // add peek blocks
+ for _, peek := range peeks {
+ if peek.New {
+ // send full room state down instead of a delta
+ var s []types.StreamEvent
+ s, err = d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ continue
+ }
+ return nil, nil, err
+ }
+ state[peek.RoomID] = s
+ }
+ if !peek.Deleted {
+ deltas = append(deltas, types.StateDelta{
+ Membership: gomatrixserverlib.Peek,
+ StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]),
+ RoomID: peek.RoomID,
+ })
+ }
+ }
+
+ // handle newly joined rooms and non-joined rooms
+ newlyJoinedRooms := make(map[string]bool, len(state))
+ for roomID, stateStreamEvents := range state {
+ for _, ev := range stateStreamEvents {
+ if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" {
+ if membership == gomatrixserverlib.Join && prevMembership != membership {
+ // send full room state down instead of a delta
+ var s []types.StreamEvent
+ s, err = d.currentStateStreamEventsForRoom(ctx, roomID, stateFilter)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ continue
+ }
+ return nil, nil, err
+ }
+ state[roomID] = s
+ newlyJoinedRooms[roomID] = true
+ continue // we'll add this room in when we do joined rooms
+ }
+
+ deltas = append(deltas, types.StateDelta{
+ Membership: membership,
+ MembershipPos: ev.StreamPosition,
+ StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
+ RoomID: roomID,
+ })
+ break
+ }
+ }
+ }
+
+ // Add in currently joined rooms
+ for _, joinedRoomID := range joinedRoomIDs {
+ deltas = append(deltas, types.StateDelta{
+ Membership: gomatrixserverlib.Join,
+ StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
+ RoomID: joinedRoomID,
+ NewlyJoined: newlyJoinedRooms[joinedRoomID],
+ })
+ }
+
+ return deltas, joinedRoomIDs, nil
+}
+
+// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
+// requests with full_state=true.
+// Fetches full state for all joined rooms and uses selectStateInRange to get
+// updates for other rooms.
+func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
+ ctx context.Context, device *userapi.Device,
+ r types.Range, userID string,
+ stateFilter *gomatrixserverlib.StateFilter,
+) ([]types.StateDelta, []string, error) {
+ // Look up all memberships for the user. We only care about rooms that a
+ // user has ever interacted with — joined to, kicked/banned from, left.
+ memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, d.txn, userID)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil, nil
+ }
+ return nil, nil, err
+ }
+
+ allRoomIDs := make([]string, 0, len(memberships))
+ joinedRoomIDs := make([]string, 0, len(memberships))
+ for roomID, membership := range memberships {
+ allRoomIDs = append(allRoomIDs, roomID)
+ if membership == gomatrixserverlib.Join {
+ joinedRoomIDs = append(joinedRoomIDs, roomID)
+ }
+ }
+
+ // Use a reasonable initial capacity
+ deltas := make(map[string]types.StateDelta)
+
+ peeks, err := d.Peeks.SelectPeeksInRange(ctx, d.txn, userID, device.ID, r)
+ if err != nil && err != sql.ErrNoRows {
+ return nil, nil, err
+ }
+
+ // Add full states for all peeking rooms
+ for _, peek := range peeks {
+ if !peek.Deleted {
+ s, stateErr := d.currentStateStreamEventsForRoom(ctx, peek.RoomID, stateFilter)
+ if stateErr != nil {
+ if stateErr == sql.ErrNoRows {
+ continue
+ }
+ return nil, nil, stateErr
+ }
+ deltas[peek.RoomID] = types.StateDelta{
+ Membership: gomatrixserverlib.Peek,
+ StateEvents: d.StreamEventsToEvents(device, s),
+ RoomID: peek.RoomID,
+ }
+ }
+ }
+
+ // Get all the state events ever between these two positions
+ stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil, nil
+ }
+ return nil, nil, err
+ }
+ state, err := d.fetchStateEvents(ctx, d.txn, stateNeeded, eventMap)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil, nil
+ }
+ return nil, nil, err
+ }
+
+ for roomID, stateStreamEvents := range state {
+ for _, ev := range stateStreamEvents {
+ if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" {
+ if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
+ deltas[roomID] = types.StateDelta{
+ Membership: membership,
+ MembershipPos: ev.StreamPosition,
+ StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
+ RoomID: roomID,
+ }
+ }
+
+ break
+ }
+ }
+ }
+
+ // Add full states for all joined rooms
+ for _, joinedRoomID := range joinedRoomIDs {
+ s, stateErr := d.currentStateStreamEventsForRoom(ctx, joinedRoomID, stateFilter)
+ if stateErr != nil {
+ if stateErr == sql.ErrNoRows {
+ continue
+ }
+ return nil, nil, stateErr
+ }
+ deltas[joinedRoomID] = types.StateDelta{
+ Membership: gomatrixserverlib.Join,
+ StateEvents: d.StreamEventsToEvents(device, s),
+ RoomID: joinedRoomID,
+ }
+ }
+
+ // Create a response array.
+ result := make([]types.StateDelta, len(deltas))
+ i := 0
+ for _, delta := range deltas {
+ result[i] = delta
+ i++
+ }
+
+ return result, joinedRoomIDs, nil
+}
+
+func (d *DatabaseTransaction) currentStateStreamEventsForRoom(
+ ctx context.Context, roomID string,
+ stateFilter *gomatrixserverlib.StateFilter,
+) ([]types.StreamEvent, error) {
+ allState, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, stateFilter, nil)
+ if err != nil {
+ return nil, err
+ }
+ s := make([]types.StreamEvent, len(allState))
+ for i := 0; i < len(s); i++ {
+ s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0}
+ }
+ return s, nil
+}
+
+func (d *DatabaseTransaction) SendToDeviceUpdatesForSync(
+ ctx context.Context,
+ userID, deviceID string,
+ from, to types.StreamPosition,
+) (types.StreamPosition, []types.SendToDeviceEvent, error) {
+ // First of all, get our send-to-device updates for this user.
+ lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, d.txn, userID, deviceID, from, to)
+ if err != nil {
+ return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
+ }
+ // If there's nothing to do then stop here.
+ if len(events) == 0 {
+ return to, nil, nil
+ }
+ return lastPos, events, nil
+}
+
+func (d *DatabaseTransaction) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) {
+ _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, d.txn, roomIDs, streamPos)
+ return receipts, err
+}
+
+func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
+ roomIDs := make([]string, 0, len(rooms))
+ for roomID, membership := range rooms {
+ if membership != gomatrixserverlib.Join {
+ continue
+ }
+ roomIDs = append(roomIDs, roomID)
+ }
+ return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs)
+}
+
+func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
+ return d.Presence.GetPresenceForUser(ctx, d.txn, userID)
+}
+
+func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
+ return d.Presence.GetPresenceAfter(ctx, d.txn, after, filter)
+}
+
+func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
+ return d.Presence.GetMaxPresenceID(ctx, d.txn)
+}
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
deleted file mode 100644
index a05e6880..00000000
--- a/syncapi/storage/shared/syncserver.go
+++ /dev/null
@@ -1,1103 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package shared
-
-import (
- "context"
- "database/sql"
- "encoding/json"
- "fmt"
-
- "github.com/tidwall/gjson"
-
- userapi "github.com/matrix-org/dendrite/userapi/api"
-
- "github.com/matrix-org/gomatrixserverlib"
- "github.com/sirupsen/logrus"
-
- "github.com/matrix-org/dendrite/internal/eventutil"
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/syncapi/storage/tables"
- "github.com/matrix-org/dendrite/syncapi/types"
-)
-
-// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
-// For now this contains the shared functions
-type Database struct {
- DB *sql.DB
- Writer sqlutil.Writer
- Invites tables.Invites
- Peeks tables.Peeks
- AccountData tables.AccountData
- OutputEvents tables.Events
- Topology tables.Topology
- CurrentRoomState tables.CurrentRoomState
- BackwardExtremities tables.BackwardsExtremities
- SendToDevice tables.SendToDevice
- Filter tables.Filter
- Receipts tables.Receipts
- Memberships tables.Memberships
- NotificationData tables.NotificationData
- Ignores tables.Ignores
- Presence tables.Presence
-}
-
-func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) {
- return d.DB.BeginTx(ctx, &sql.TxOptions{
- // Set the isolation level so that we see a snapshot of the database.
- // In PostgreSQL repeatable read transactions will see a snapshot taken
- // at the first query, and since the transaction is read-only it can't
- // run into any serialisation errors.
- // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ
- Isolation: sql.LevelRepeatableRead,
- ReadOnly: true,
- })
-}
-
-func (d *Database) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) {
- id, err := d.OutputEvents.SelectMaxEventID(ctx, nil)
- if err != nil {
- return 0, fmt.Errorf("d.OutputEvents.SelectMaxEventID: %w", err)
- }
- return types.StreamPosition(id), nil
-}
-
-func (d *Database) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) {
- id, err := d.Receipts.SelectMaxReceiptID(ctx, nil)
- if err != nil {
- return 0, fmt.Errorf("d.Receipts.SelectMaxReceiptID: %w", err)
- }
- return types.StreamPosition(id), nil
-}
-
-func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) {
- id, err := d.Invites.SelectMaxInviteID(ctx, nil)
- if err != nil {
- return 0, fmt.Errorf("d.Invites.SelectMaxInviteID: %w", err)
- }
- return types.StreamPosition(id), nil
-}
-
-func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
- id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil)
- if err != nil {
- return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
- }
- return types.StreamPosition(id), nil
-}
-
-func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
- id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil)
- if err != nil {
- return 0, fmt.Errorf("d.Invites.SelectMaxAccountDataID: %w", err)
- }
- return types.StreamPosition(id), nil
-}
-
-func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
- id, err := d.NotificationData.SelectMaxID(ctx, nil)
- if err != nil {
- return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
- }
- return types.StreamPosition(id), nil
-}
-
-func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
- return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs)
-}
-
-func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) {
- return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
-}
-
-func (d *Database) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) {
- return d.Memberships.SelectMembershipCount(ctx, nil, roomID, membership, pos)
-}
-
-func (d *Database) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) {
- return d.Memberships.SelectHeroes(ctx, nil, roomID, userID, memberships)
-}
-
-func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
- return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
-}
-
-func (d *Database) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
- return d.Topology.SelectPositionInTopology(ctx, nil, eventID)
-}
-
-func (d *Database) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
- return d.Invites.SelectInviteEventsInRange(ctx, nil, targetUserID, r)
-}
-
-func (d *Database) PeeksInRange(ctx context.Context, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) {
- return d.Peeks.SelectPeeksInRange(ctx, nil, userID, deviceID, r)
-}
-
-func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
- return d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos)
-}
-
-// Events lookups a list of event by their event ID.
-// Returns a list of events matching the requested IDs found in the database.
-// If an event is not found in the database then it will be omitted from the list.
-// Returns an error if there was a problem talking with the database.
-// Does not include any transaction IDs in the returned events.
-func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
- streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
- if err != nil {
- return nil, err
- }
-
- // We don't include a device here as we only include transaction IDs in
- // incremental syncs.
- return d.StreamEventsToEvents(nil, streamEvents), nil
-}
-
-func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
- return d.CurrentRoomState.SelectJoinedUsers(ctx, nil)
-}
-
-func (d *Database) AllJoinedUsersInRoom(ctx context.Context, roomIDs []string) (map[string][]string, error) {
- return d.CurrentRoomState.SelectJoinedUsersInRoom(ctx, nil, roomIDs)
-}
-
-func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
- return d.Peeks.SelectPeekingDevices(ctx, nil)
-}
-
-func (d *Database) SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) {
- return d.CurrentRoomState.SelectSharedUsers(ctx, nil, userID, otherUserIDs)
-}
-
-func (d *Database) GetStateEvent(
- ctx context.Context, roomID, evType, stateKey string,
-) (*gomatrixserverlib.HeaderedEvent, error) {
- return d.CurrentRoomState.SelectStateEvent(ctx, nil, roomID, evType, stateKey)
-}
-
-func (d *Database) GetStateEventsForRoom(
- ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
-) (stateEvents []*gomatrixserverlib.HeaderedEvent, err error) {
- stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter, nil)
- return
-}
-
-// AddInviteEvent stores a new invite event for a user.
-// If the invite was successfully stored this returns the stream ID it was stored at.
-// Returns an error if there was a problem communicating with the database.
-func (d *Database) AddInviteEvent(
- ctx context.Context, inviteEvent *gomatrixserverlib.HeaderedEvent,
-) (sp types.StreamPosition, err error) {
- _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
- return err
- })
- return
-}
-
-// RetireInviteEvent removes an old invite event from the database.
-// Returns an error if there was a problem communicating with the database.
-func (d *Database) RetireInviteEvent(
- ctx context.Context, inviteEventID string,
-) (sp types.StreamPosition, err error) {
- _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID)
- return err
- })
- return
-}
-
-// AddPeek tracks the fact that a user has started peeking.
-// If the peek was successfully stored this returns the stream ID it was stored at.
-// Returns an error if there was a problem communicating with the database.
-func (d *Database) AddPeek(
- ctx context.Context, roomID, userID, deviceID string,
-) (sp types.StreamPosition, err error) {
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID)
- return err
- })
- return
-}
-
-// DeletePeek tracks the fact that a user has stopped peeking from the specified
-// device. If the peeks was successfully deleted this returns the stream ID it was
-// stored at. Returns an error if there was a problem communicating with the database.
-func (d *Database) DeletePeek(
- ctx context.Context, roomID, userID, deviceID string,
-) (sp types.StreamPosition, err error) {
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- sp, err = d.Peeks.DeletePeek(ctx, txn, roomID, userID, deviceID)
- return err
- })
- if err == sql.ErrNoRows {
- sp = 0
- err = nil
- }
- return
-}
-
-// DeletePeeks tracks the fact that a user has stopped peeking from all devices
-// If the peeks was successfully deleted this returns the stream ID it was stored at.
-// Returns an error if there was a problem communicating with the database.
-func (d *Database) DeletePeeks(
- ctx context.Context, roomID, userID string,
-) (sp types.StreamPosition, err error) {
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID)
- return err
- })
- if err == sql.ErrNoRows {
- sp = 0
- err = nil
- }
- return
-}
-
-// GetAccountDataInRange returns all account data for a given user inserted or
-// updated between two given positions
-// Returns a map following the format data[roomID] = []dataTypes
-// If no data is retrieved, returns an empty map
-// If there was an issue with the retrieval, returns an error
-func (d *Database) GetAccountDataInRange(
- ctx context.Context, userID string, r types.Range,
- accountDataFilterPart *gomatrixserverlib.EventFilter,
-) (map[string][]string, types.StreamPosition, error) {
- return d.AccountData.SelectAccountDataInRange(ctx, nil, userID, r, accountDataFilterPart)
-}
-
-// UpsertAccountData keeps track of new or updated account data, by saving the type
-// of the new/updated data, and the user ID and room ID the data is related to (empty)
-// room ID means the data isn't specific to any room)
-// If no data with the given type, user ID and room ID exists in the database,
-// creates a new row, else update the existing one
-// Returns an error if there was an issue with the upsert
-func (d *Database) UpsertAccountData(
- ctx context.Context, userID, roomID, dataType string,
-) (sp types.StreamPosition, err error) {
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
- return err
- })
- return
-}
-
-func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent {
- out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
- for i := 0; i < len(in); i++ {
- out[i] = in[i].HeaderedEvent
- if device != nil && in[i].TransactionID != nil {
- if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
- err := out[i].SetUnsignedField(
- "transaction_id", in[i].TransactionID.TransactionID,
- )
- if err != nil {
- logrus.WithFields(logrus.Fields{
- "event_id": out[i].EventID(),
- }).WithError(err).Warnf("Failed to add transaction ID to event")
- }
- }
- }
- }
- return out
-}
-
-// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
-// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
-// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
-// This function should always be called within a sqlutil.Writer for safety in SQLite.
-func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
- if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
- return err
- }
-
- // Check if we have all of the event's previous events. If an event is
- // missing, add it to the room's backward extremities.
- prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
- if err != nil {
- return err
- }
- var found bool
- for _, eID := range ev.PrevEventIDs() {
- found = false
- for _, prevEv := range prevEvents {
- if eID == prevEv.EventID() {
- found = true
- }
- }
-
- // If the event is missing, consider it a backward extremity.
- if !found {
- if err = d.BackwardExtremities.InsertsBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID(), eID); err != nil {
- return err
- }
- }
- }
-
- return nil
-}
-
-func (d *Database) PurgeRoomState(
- ctx context.Context, roomID string,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- // If the event is a create event then we'll delete all of the existing
- // data for the room. The only reason that a create event would be replayed
- // to us in this way is if we're about to receive the entire room state.
- if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
- return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
- }
- return nil
- })
-}
-
-func (d *Database) WriteEvent(
- ctx context.Context,
- ev *gomatrixserverlib.HeaderedEvent,
- addStateEvents []*gomatrixserverlib.HeaderedEvent,
- addStateEventIDs, removeStateEventIDs []string,
- transactionID *api.TransactionID, excludeFromSync bool,
- historyVisibility gomatrixserverlib.HistoryVisibility,
-) (pduPosition types.StreamPosition, returnErr error) {
- returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- var err error
- ev.Visibility = historyVisibility
- pos, err := d.OutputEvents.InsertEvent(
- ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, historyVisibility,
- )
- if err != nil {
- return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
- }
- pduPosition = pos
- var topoPosition types.StreamPosition
- if topoPosition, err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
- return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
- }
-
- if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
- return fmt.Errorf("d.handleBackwardExtremities: %w", err)
- }
-
- if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
- // Nothing to do, the event may have just been a message event.
- return nil
- }
- for i := range addStateEvents {
- addStateEvents[i].Visibility = historyVisibility
- }
- return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition, topoPosition)
- })
-
- return pduPosition, returnErr
-}
-
-// This function should always be called within a sqlutil.Writer for safety in SQLite.
-func (d *Database) updateRoomState(
- ctx context.Context, txn *sql.Tx,
- removedEventIDs []string,
- addedEvents []*gomatrixserverlib.HeaderedEvent,
- pduPosition types.StreamPosition,
- topoPosition types.StreamPosition,
-) error {
- // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
- for _, eventID := range removedEventIDs {
- if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
- return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
- }
- }
-
- for _, event := range addedEvents {
- if event.StateKey() == nil {
- // ignore non state events
- continue
- }
- var membership *string
- if event.Type() == "m.room.member" {
- value, err := event.Membership()
- if err != nil {
- return fmt.Errorf("event.Membership: %w", err)
- }
- membership = &value
- if err = d.Memberships.UpsertMembership(ctx, txn, event, pduPosition, topoPosition); err != nil {
- return fmt.Errorf("d.Memberships.UpsertMembership: %w", err)
- }
- }
-
- if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
- return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
- }
- }
-
- return nil
-}
-
-func (d *Database) GetEventsInTopologicalRange(
- ctx context.Context,
- from, to *types.TopologyToken,
- roomID string,
- filter *gomatrixserverlib.RoomEventFilter,
- backwardOrdering bool,
-) (events []types.StreamEvent, err error) {
- var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
- if backwardOrdering {
- // Backward ordering means the 'from' token has a higher depth than the 'to' token
- minDepth = to.Depth
- maxDepth = from.Depth
- // for cases where we have say 5 events with the same depth, the TopologyToken needs to
- // know which of the 5 the client has seen. This is done by using the PDU position.
- // Events with the same maxDepth but less than this PDU position will be returned.
- maxStreamPosForMaxDepth = from.PDUPosition
- } else {
- // Forward ordering means the 'from' token has a lower depth than the 'to' token.
- minDepth = from.Depth
- maxDepth = to.Depth
- }
-
- // Select the event IDs from the defined range.
- var eIDs []string
- eIDs, err = d.Topology.SelectEventIDsInRange(
- ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
- )
- if err != nil {
- return
- }
-
- // Retrieve the events' contents using their IDs.
- events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true)
- return
-}
-
-func (d *Database) BackwardExtremitiesForRoom(
- ctx context.Context, roomID string,
-) (backwardExtremities map[string][]string, err error) {
- return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, nil, roomID)
-}
-
-func (d *Database) MaxTopologicalPosition(
- ctx context.Context, roomID string,
-) (types.TopologyToken, error) {
- depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID)
- if err != nil {
- return types.TopologyToken{}, err
- }
- return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
-}
-
-func (d *Database) EventPositionInTopology(
- ctx context.Context, eventID string,
-) (types.TopologyToken, error) {
- depth, stream, err := d.Topology.SelectPositionInTopology(ctx, nil, eventID)
- if err != nil {
- return types.TopologyToken{}, err
- }
- return types.TopologyToken{Depth: depth, PDUPosition: stream}, nil
-}
-
-func (d *Database) StreamToTopologicalPosition(
- ctx context.Context, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
-) (types.TopologyToken, error) {
- topoPos, err := d.Topology.SelectStreamToTopologicalPosition(ctx, nil, roomID, streamPos, backwardOrdering)
- switch {
- case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward
- return types.TopologyToken{PDUPosition: streamPos}, nil
- case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward
- topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, nil, roomID)
- if err != nil {
- return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err)
- }
- return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
- case err != nil: // some other error happened
- return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err)
- default:
- return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
- }
-}
-
-func (d *Database) GetFilter(
- ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
-) error {
- return d.Filter.SelectFilter(ctx, nil, target, localpart, filterID)
-}
-
-func (d *Database) PutFilter(
- ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
-) (string, error) {
- var filterID string
- var err error
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- filterID, err = d.Filter.InsertFilter(ctx, txn, filter, localpart)
- return err
- })
- return filterID, err
-}
-
-func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, redactedBecause *gomatrixserverlib.HeaderedEvent) error {
- redactedEvents, err := d.Events(ctx, []string{redactedEventID})
- if err != nil {
- return err
- }
- if len(redactedEvents) == 0 {
- logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction")
- return nil
- }
- eventToRedact := redactedEvents[0].Unwrap()
- redactionEvent := redactedBecause.Unwrap()
- if err = eventutil.RedactEvent(redactionEvent, eventToRedact); err != nil {
- return err
- }
-
- newEvent := eventToRedact.Headered(redactedBecause.RoomVersion)
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.OutputEvents.UpdateEventJSON(ctx, txn, newEvent)
- })
- return err
-}
-
-// GetBackwardTopologyPos retrieves the backward topology position, i.e. the position of the
-// oldest event in the room's topology.
-func (d *Database) GetBackwardTopologyPos(
- ctx context.Context,
- events []types.StreamEvent,
-) (types.TopologyToken, error) {
- zeroToken := types.TopologyToken{}
- if len(events) == 0 {
- return zeroToken, nil
- }
- pos, spos, err := d.Topology.SelectPositionInTopology(ctx, nil, events[0].EventID())
- if err != nil {
- return zeroToken, err
- }
- tok := types.TopologyToken{Depth: pos, PDUPosition: spos}
- tok.Decrement()
- return tok, nil
-}
-
-// fetchStateEvents converts the set of event IDs into a set of events. It will fetch any which are missing from the database.
-// Returns a map of room ID to list of events.
-func (d *Database) fetchStateEvents(
- ctx context.Context, txn *sql.Tx,
- roomIDToEventIDSet map[string]map[string]bool,
- eventIDToEvent map[string]types.StreamEvent,
-) (map[string][]types.StreamEvent, error) {
- stateBetween := make(map[string][]types.StreamEvent)
- missingEvents := make(map[string][]string)
- for roomID, ids := range roomIDToEventIDSet {
- events := stateBetween[roomID]
- for id, need := range ids {
- if !need {
- continue // deleted state
- }
- e, ok := eventIDToEvent[id]
- if ok {
- events = append(events, e)
- } else {
- m := missingEvents[roomID]
- m = append(m, id)
- missingEvents[roomID] = m
- }
- }
- stateBetween[roomID] = events
- }
-
- if len(missingEvents) > 0 {
- // This happens when add_state_ids has an event ID which is not in the provided range.
- // We need to explicitly fetch them.
- allMissingEventIDs := []string{}
- for _, missingEvIDs := range missingEvents {
- allMissingEventIDs = append(allMissingEventIDs, missingEvIDs...)
- }
- evs, err := d.fetchMissingStateEvents(ctx, txn, allMissingEventIDs)
- if err != nil {
- return nil, err
- }
- // we know we got them all otherwise an error would've been returned, so just loop the events
- for _, ev := range evs {
- roomID := ev.RoomID()
- stateBetween[roomID] = append(stateBetween[roomID], ev)
- }
- }
- return stateBetween, nil
-}
-
-func (d *Database) fetchMissingStateEvents(
- ctx context.Context, txn *sql.Tx, eventIDs []string,
-) ([]types.StreamEvent, error) {
- // Fetch from the events table first so we pick up the stream ID for the
- // event.
- events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
- if err != nil {
- return nil, err
- }
-
- have := map[string]bool{}
- for _, event := range events {
- have[event.EventID()] = true
- }
- var missing []string
- for _, eventID := range eventIDs {
- if !have[eventID] {
- missing = append(missing, eventID)
- }
- }
- if len(missing) == 0 {
- return events, nil
- }
-
- // If they are missing from the events table then they should be state
- // events that we received from outside the main event stream.
- // These should be in the room state table.
- stateEvents, err := d.CurrentRoomState.SelectEventsWithEventIDs(ctx, txn, missing)
-
- if err != nil {
- return nil, err
- }
- if len(stateEvents) != len(missing) {
- logrus.WithContext(ctx).Warnf("Failed to map all event IDs to events (got %d, wanted %d)", len(stateEvents), len(missing))
-
- // TODO: Why is this happening? It's probably the roomserver. Uncomment
- // this error again when we work out what it is and fix it, otherwise we
- // just end up returning lots of 500s to the client and that breaks
- // pretty much everything, rather than just sending what we have.
- //return nil, fmt.Errorf("failed to map all event IDs to events: (got %d, wanted %d)", len(stateEvents), len(missing))
- }
- events = append(events, stateEvents...)
- return events, nil
-}
-
-// GetStateDeltas returns the state deltas between fromPos and toPos,
-// exclusive of oldPos, inclusive of newPos, for the rooms in which
-// the user has new membership events.
-// A list of joined room IDs is also returned in case the caller needs it.
-func (d *Database) GetStateDeltas(
- ctx context.Context, device *userapi.Device,
- r types.Range, userID string,
- stateFilter *gomatrixserverlib.StateFilter,
-) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) {
- // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
- // - Get membership list changes for this user in this sync response
- // - For each room which has membership list changes:
- // * Check if the room is 'newly joined' (insufficient to just check for a join event because we allow dupe joins TODO).
- // If it is, then we need to send the full room state down (and 'limited' is always true).
- // * Check if user is still CURRENTLY invited to the room. If so, add room to 'invited' block.
- // * Check if the user is CURRENTLY (TODO) left/banned. If so, add room to 'archived' block.
- // - Get all CURRENTLY joined rooms, and add them to 'joined' block.
- txn, err := d.readOnlySnapshot(ctx)
- if err != nil {
- return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
- }
- var succeeded bool
- defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
-
- // Look up all memberships for the user. We only care about rooms that a
- // user has ever interacted with — joined to, kicked/banned from, left.
- memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil, nil
- }
- return nil, nil, err
- }
-
- allRoomIDs := make([]string, 0, len(memberships))
- joinedRoomIDs := make([]string, 0, len(memberships))
- for roomID, membership := range memberships {
- allRoomIDs = append(allRoomIDs, roomID)
- if membership == gomatrixserverlib.Join {
- joinedRoomIDs = append(joinedRoomIDs, roomID)
- }
- }
-
- // get all the state events ever (i.e. for all available rooms) between these two positions
- stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil, nil
- }
- return nil, nil, err
- }
- state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil, nil
- }
- return nil, nil, err
- }
-
- // find out which rooms this user is peeking, if any.
- // We do this before joins so any peeks get overwritten
- peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
- if err != nil && err != sql.ErrNoRows {
- return nil, nil, err
- }
-
- // add peek blocks
- for _, peek := range peeks {
- if peek.New {
- // send full room state down instead of a delta
- var s []types.StreamEvent
- s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
- if err != nil {
- if err == sql.ErrNoRows {
- continue
- }
- return nil, nil, err
- }
- state[peek.RoomID] = s
- }
- if !peek.Deleted {
- deltas = append(deltas, types.StateDelta{
- Membership: gomatrixserverlib.Peek,
- StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]),
- RoomID: peek.RoomID,
- })
- }
- }
-
- // handle newly joined rooms and non-joined rooms
- newlyJoinedRooms := make(map[string]bool, len(state))
- for roomID, stateStreamEvents := range state {
- for _, ev := range stateStreamEvents {
- if membership, prevMembership := getMembershipFromEvent(ev.Event, userID); membership != "" {
- if membership == gomatrixserverlib.Join && prevMembership != membership {
- // send full room state down instead of a delta
- var s []types.StreamEvent
- s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter)
- if err != nil {
- if err == sql.ErrNoRows {
- continue
- }
- return nil, nil, err
- }
- state[roomID] = s
- newlyJoinedRooms[roomID] = true
- continue // we'll add this room in when we do joined rooms
- }
-
- deltas = append(deltas, types.StateDelta{
- Membership: membership,
- MembershipPos: ev.StreamPosition,
- StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
- RoomID: roomID,
- })
- break
- }
- }
- }
-
- // Add in currently joined rooms
- for _, joinedRoomID := range joinedRoomIDs {
- deltas = append(deltas, types.StateDelta{
- Membership: gomatrixserverlib.Join,
- StateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
- RoomID: joinedRoomID,
- NewlyJoined: newlyJoinedRooms[joinedRoomID],
- })
- }
-
- succeeded = true
- return deltas, joinedRoomIDs, nil
-}
-
-// GetStateDeltasForFullStateSync is a variant of getStateDeltas used for /sync
-// requests with full_state=true.
-// Fetches full state for all joined rooms and uses selectStateInRange to get
-// updates for other rooms.
-func (d *Database) GetStateDeltasForFullStateSync(
- ctx context.Context, device *userapi.Device,
- r types.Range, userID string,
- stateFilter *gomatrixserverlib.StateFilter,
-) ([]types.StateDelta, []string, error) {
- txn, err := d.readOnlySnapshot(ctx)
- if err != nil {
- return nil, nil, fmt.Errorf("d.readOnlySnapshot: %w", err)
- }
- var succeeded bool
- defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err)
-
- // Look up all memberships for the user. We only care about rooms that a
- // user has ever interacted with — joined to, kicked/banned from, left.
- memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil, nil
- }
- return nil, nil, err
- }
-
- allRoomIDs := make([]string, 0, len(memberships))
- joinedRoomIDs := make([]string, 0, len(memberships))
- for roomID, membership := range memberships {
- allRoomIDs = append(allRoomIDs, roomID)
- if membership == gomatrixserverlib.Join {
- joinedRoomIDs = append(joinedRoomIDs, roomID)
- }
- }
-
- // Use a reasonable initial capacity
- deltas := make(map[string]types.StateDelta)
-
- peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
- if err != nil && err != sql.ErrNoRows {
- return nil, nil, err
- }
-
- // Add full states for all peeking rooms
- for _, peek := range peeks {
- if !peek.Deleted {
- s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
- if stateErr != nil {
- if stateErr == sql.ErrNoRows {
- continue
- }
- return nil, nil, stateErr
- }
- deltas[peek.RoomID] = types.StateDelta{
- Membership: gomatrixserverlib.Peek,
- StateEvents: d.StreamEventsToEvents(device, s),
- RoomID: peek.RoomID,
- }
- }
- }
-
- // Get all the state events ever between these two positions
- stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil, nil
- }
- return nil, nil, err
- }
- state, err := d.fetchStateEvents(ctx, txn, stateNeeded, eventMap)
- if err != nil {
- if err == sql.ErrNoRows {
- return nil, nil, nil
- }
- return nil, nil, err
- }
-
- for roomID, stateStreamEvents := range state {
- for _, ev := range stateStreamEvents {
- if membership, _ := getMembershipFromEvent(ev.Event, userID); membership != "" {
- if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
- deltas[roomID] = types.StateDelta{
- Membership: membership,
- MembershipPos: ev.StreamPosition,
- StateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
- RoomID: roomID,
- }
- }
-
- break
- }
- }
- }
-
- // Add full states for all joined rooms
- for _, joinedRoomID := range joinedRoomIDs {
- s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter)
- if stateErr != nil {
- if stateErr == sql.ErrNoRows {
- continue
- }
- return nil, nil, stateErr
- }
- deltas[joinedRoomID] = types.StateDelta{
- Membership: gomatrixserverlib.Join,
- StateEvents: d.StreamEventsToEvents(device, s),
- RoomID: joinedRoomID,
- }
- }
-
- // Create a response array.
- result := make([]types.StateDelta, len(deltas))
- i := 0
- for _, delta := range deltas {
- result[i] = delta
- i++
- }
-
- succeeded = true
- return result, joinedRoomIDs, nil
-}
-
-func (d *Database) currentStateStreamEventsForRoom(
- ctx context.Context, txn *sql.Tx, roomID string,
- stateFilter *gomatrixserverlib.StateFilter,
-) ([]types.StreamEvent, error) {
- allState, err := d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter, nil)
- if err != nil {
- return nil, err
- }
- s := make([]types.StreamEvent, len(allState))
- for i := 0; i < len(s); i++ {
- s[i] = types.StreamEvent{HeaderedEvent: allState[i], StreamPosition: 0}
- }
- return s, nil
-}
-
-func (d *Database) StoreNewSendForDeviceMessage(
- ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
-) (newPos types.StreamPosition, err error) {
- j, err := json.Marshal(event)
- if err != nil {
- return 0, err
- }
- // Delegate the database write task to the SendToDeviceWriter. It'll guarantee
- // that we don't lock the table for writes in more than one place.
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
- ctx, txn, userID, deviceID, string(j),
- )
- return err
- })
- if err != nil {
- return 0, err
- }
- return newPos, nil
-}
-
-func (d *Database) SendToDeviceUpdatesForSync(
- ctx context.Context,
- userID, deviceID string,
- from, to types.StreamPosition,
-) (types.StreamPosition, []types.SendToDeviceEvent, error) {
- // First of all, get our send-to-device updates for this user.
- lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to)
- if err != nil {
- return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
- }
- // If there's nothing to do then stop here.
- if len(events) == 0 {
- return to, nil, nil
- }
- return lastPos, events, nil
-}
-
-func (d *Database) CleanSendToDeviceUpdates(
- ctx context.Context,
- userID, deviceID string, before types.StreamPosition,
-) (err error) {
- if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before)
- }); err != nil {
- logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID)
- return err
- }
- return nil
-}
-
-// getMembershipFromEvent returns the value of content.membership iff the event is a state event
-// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
-func getMembershipFromEvent(ev *gomatrixserverlib.Event, userID string) (string, string) {
- if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
- return "", ""
- }
- membership, err := ev.Membership()
- if err != nil {
- return "", ""
- }
- prevMembership := gjson.GetBytes(ev.Unsigned(), "prev_content.membership").Str
- return membership, prevMembership
-}
-
-// StoreReceipt stores user receipts
-func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) {
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- pos, err = d.Receipts.UpsertReceipt(ctx, txn, roomId, receiptType, userId, eventId, timestamp)
- return err
- })
- return
-}
-
-func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) {
- _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, nil, roomIDs, streamPos)
- return receipts, err
-}
-
-func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, txn, userID, roomID, notificationCount, highlightCount)
- return err
- })
- return
-}
-
-func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) {
- roomIDs := make([]string, 0, len(rooms))
- for roomID, membership := range rooms {
- if membership != gomatrixserverlib.Join {
- continue
- }
- roomIDs = append(roomIDs, roomID)
- }
- return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, nil, userID, roomIDs)
-}
-
-func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
- return d.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
-}
-
-func (d *Database) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) {
- return d.OutputEvents.SelectContextBeforeEvent(ctx, nil, id, roomID, filter)
-}
-func (d *Database) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) {
- return d.OutputEvents.SelectContextAfterEvent(ctx, nil, id, roomID, filter)
-}
-
-func (d *Database) IgnoresForUser(ctx context.Context, userID string) (*types.IgnoredUsers, error) {
- return d.Ignores.SelectIgnores(ctx, nil, userID)
-}
-
-func (d *Database) UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.Ignores.UpsertIgnores(ctx, txn, userID, ignores)
- })
-}
-
-func (d *Database) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) {
- var pos types.StreamPosition
- var err error
- _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- pos, err = d.Presence.UpsertPresence(ctx, txn, userID, statusMsg, presence, lastActiveTS, fromSync)
- return nil
- })
- return pos, err
-}
-
-func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) {
- return d.Presence.GetPresenceForUser(ctx, nil, userID)
-}
-
-func (d *Database) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) {
- return d.Presence.GetPresenceAfter(ctx, nil, after, filter)
-}
-
-func (d *Database) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
- return d.Presence.GetMaxPresenceID(ctx, nil)
-}
-
-func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) {
- return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
-}
-
-func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
- return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{
- gomatrixserverlib.MRoomName,
- gomatrixserverlib.MRoomTopic,
- "m.room.message",
- })
-}
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index ba6d8126..c4019fed 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -367,7 +367,13 @@ func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
for start < len(eventIDs) {
n := minOfInts(len(eventIDs)-start, 999)
query := strings.Replace(selectEventsWithEventIDsSQL, "($1)", sqlutil.QueryVariadic(n), 1)
- rows, err := txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
+ var rows *sql.Rows
+ var err error
+ if txn == nil {
+ rows, err = s.db.QueryContext(ctx, query, iEventIDs[start:start+n]...)
+ } else {
+ rows, err = txn.QueryContext(ctx, query, iEventIDs[start:start+n]...)
+ }
if err != nil {
return nil, err
}
diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go
index 58ab8461..e2dbcd5c 100644
--- a/syncapi/storage/sqlite3/invites_table.go
+++ b/syncapi/storage/sqlite3/invites_table.go
@@ -50,7 +50,7 @@ const deleteInviteEventSQL = "" +
"UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false"
const selectInviteEventsInRangeSQL = "" +
- "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +
+ "SELECT id, room_id, headered_event_json, deleted FROM syncapi_invite_events" +
" WHERE target_user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC"
@@ -132,23 +132,28 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
// active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) SelectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
-) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, error) {
+) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
+ var lastPos types.StreamPosition
stmt := sqlutil.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
if err != nil {
- return nil, nil, err
+ return nil, nil, lastPos, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectInviteEventsInRange: rows.close() failed")
result := map[string]*gomatrixserverlib.HeaderedEvent{}
retired := map[string]*gomatrixserverlib.HeaderedEvent{}
for rows.Next() {
var (
+ id types.StreamPosition
roomID string
eventJSON []byte
deleted bool
)
- if err = rows.Scan(&roomID, &eventJSON, &deleted); err != nil {
- return nil, nil, err
+ if err = rows.Scan(&id, &roomID, &eventJSON, &deleted); err != nil {
+ return nil, nil, lastPos, err
+ }
+ if id > lastPos {
+ lastPos = id
}
// if we have seen this room before, it has a higher stream position and hence takes priority
@@ -161,15 +166,19 @@ func (s *inviteEventsStatements) SelectInviteEventsInRange(
var event *gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(eventJSON, &event); err != nil {
- return nil, nil, err
+ return nil, nil, lastPos, err
}
+
if deleted {
retired[roomID] = event
} else {
result[roomID] = event
}
}
- return result, retired, nil
+ if lastPos == 0 {
+ lastPos = r.To
+ }
+ return result, retired, lastPos, nil
}
func (s *inviteEventsStatements) SelectMaxInviteID(
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
index a84e2bd1..0879030a 100644
--- a/syncapi/storage/sqlite3/syncserver.go
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -49,6 +49,20 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
return &d, nil
}
+func (d *SyncServerDatasource) NewDatabaseSnapshot(ctx context.Context) (*shared.DatabaseTransaction, error) {
+ return &shared.DatabaseTransaction{
+ Database: &d.Database,
+ // not setting a transaction because SQLite doesn't support it
+ }, nil
+}
+
+func (d *SyncServerDatasource) NewDatabaseTransaction(ctx context.Context) (*shared.DatabaseTransaction, error) {
+ return &shared.DatabaseTransaction{
+ Database: &d.Database,
+ // not setting a transaction because SQLite doesn't support it
+ }, nil
+}
+
func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
if err = d.streamID.Prepare(d.db); err != nil {
return err
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index a62818e9..5ff185a3 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -60,6 +60,17 @@ func TestWriteEvents(t *testing.T) {
})
}
+func WithSnapshot(t *testing.T, db storage.Database, f func(snapshot storage.DatabaseTransaction)) {
+ snapshot, err := db.NewDatabaseSnapshot(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ f(snapshot)
+ if err := snapshot.Rollback(); err != nil {
+ t.Fatal(err)
+ }
+}
+
// These tests assert basic functionality of RecentEvents for PDUs
func TestRecentEventsPDU(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@@ -79,10 +90,13 @@ func TestRecentEventsPDU(t *testing.T) {
// dummy room to make sure SQL queries are filtering on room ID
MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
- latest, err := db.MaxStreamPositionForPDUs(ctx)
- if err != nil {
- t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
- }
+ var latest types.StreamPosition
+ WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
+ var err error
+ if latest, err = snapshot.MaxStreamPositionForPDUs(ctx); err != nil {
+ t.Fatal("failed to get MaxStreamPositionForPDUs: %w", err)
+ }
+ })
testCases := []struct {
Name string
@@ -140,14 +154,19 @@ func TestRecentEventsPDU(t *testing.T) {
tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) {
var filter gomatrixserverlib.RoomEventFilter
+ var gotEvents []types.StreamEvent
+ var limited bool
filter.Limit = tc.Limit
- gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
- From: tc.From,
- To: tc.To,
- }, &filter, !tc.ReverseOrder, true)
- if err != nil {
- st.Fatalf("failed to do sync: %s", err)
- }
+ WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
+ var err error
+ gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{
+ From: tc.From,
+ To: tc.To,
+ }, &filter, !tc.ReverseOrder, true)
+ if err != nil {
+ st.Fatalf("failed to do sync: %s", err)
+ }
+ })
if limited != tc.WantLimited {
st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
}
@@ -178,22 +197,24 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
events := r.Events()
_ = MustWriteEvents(t, db, events)
- from, err := db.MaxTopologicalPosition(ctx, r.ID)
- if err != nil {
- t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
- }
- t.Logf("max topo pos = %+v", from)
- // head towards the beginning of time
- to := types.TopologyToken{}
+ WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
+ from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
+ if err != nil {
+ t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
+ }
+ t.Logf("max topo pos = %+v", from)
+ // head towards the beginning of time
+ to := types.TopologyToken{}
- // backpaginate 5 messages starting at the latest position.
- filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
- paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
- if err != nil {
- t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
- }
- gots := db.StreamEventsToEvents(nil, paginatedEvents)
- test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
+ // backpaginate 5 messages starting at the latest position.
+ filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
+ paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
+ if err != nil {
+ t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
+ }
+ gots := snapshot.StreamEventsToEvents(nil, paginatedEvents)
+ test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
+ })
})
}
@@ -414,13 +435,16 @@ func TestSendToDeviceBehaviour(t *testing.T) {
defer closeBase()
// At this point there should be no messages. We haven't sent anything
// yet.
- _, events, err := db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 0 {
- t.Fatal("first call should have no updates")
- }
+
+ WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
+ _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, 100)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 {
+ t.Fatal("first call should have no updates")
+ }
+ })
// Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, alice.ID, deviceID, gomatrixserverlib.SendToDeviceEvent{
@@ -432,51 +456,58 @@ func TestSendToDeviceBehaviour(t *testing.T) {
t.Fatal(err)
}
- // At this point we should get exactly one message. We're sending the sync position
- // that we were given from the update and the send-to-device update will be updated
- // in the database to reflect that this was the sync position we sent the message at.
- streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
- if err != nil {
- t.Fatal(err)
- }
- if count := len(events); count != 1 {
- t.Fatalf("second call should have one update, got %d", count)
- }
+ WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
+ // At this point we should get exactly one message. We're sending the sync position
+ // that we were given from the update and the send-to-device update will be updated
+ // in the database to reflect that this was the sync position we sent the message at.
+ var events []types.SendToDeviceEvent
+ streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if count := len(events); count != 1 {
+ t.Fatalf("second call should have one update, got %d", count)
+ }
+
+ // At this point we should still have one message because we haven't progressed the
+ // sync position yet. This is equivalent to the client failing to /sync and retrying
+ // with the same position.
+ streamPos, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 1 {
+ t.Fatal("third call should have one update still")
+ }
+ })
- // At this point we should still have one message because we haven't progressed the
- // sync position yet. This is equivalent to the client failing to /sync and retrying
- // with the same position.
- streamPos, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, streamPos)
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 1 {
- t.Fatal("third call should have one update still")
- }
err = db.CleanSendToDeviceUpdates(context.Background(), alice.ID, deviceID, streamPos)
if err != nil {
return
}
- // At this point we should now have no updates, because we've progressed the sync
- // position. Therefore the update from before will not be sent again.
- _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 0 {
- t.Fatal("fourth call should have no updates")
- }
+ WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
+ // At this point we should now have no updates, because we've progressed the sync
+ // position. Therefore the update from before will not be sent again.
+ var events []types.SendToDeviceEvent
+ _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 {
+ t.Fatal("fourth call should have no updates")
+ }
- // At this point we should still have no updates, because no new updates have been
- // sent.
- _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
- if err != nil {
- t.Fatal(err)
- }
- if len(events) != 0 {
- t.Fatal("fifth call should have no updates")
- }
+ // At this point we should still have no updates, because no new updates have been
+ // sent.
+ _, events, err = snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, streamPos, streamPos+10)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(events) != 0 {
+ t.Fatal("fifth call should have no updates")
+ }
+ })
// Send some more messages and verify the ordering is correct ("in order of arrival")
var lastPos types.StreamPosition = 0
@@ -492,18 +523,20 @@ func TestSendToDeviceBehaviour(t *testing.T) {
lastPos = streamPos
}
- _, events, err = db.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
- if err != nil {
- t.Fatalf("unable to get events: %v", err)
- }
+ WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
+ _, events, err := snapshot.SendToDeviceUpdatesForSync(ctx, alice.ID, deviceID, 0, lastPos)
+ if err != nil {
+ t.Fatalf("unable to get events: %v", err)
+ }
- for i := 0; i < 10; i++ {
- want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
- got := events[i].Content
- if !bytes.Equal(got, want) {
- t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
+ for i := 0; i < 10; i++ {
+ want := json.RawMessage(fmt.Sprintf(`{"count":%d}`, i))
+ got := events[i].Content
+ if !bytes.Equal(got, want) {
+ t.Fatalf("messages are out of order\nwant: %s\ngot: %s", string(want), string(got))
+ }
}
- }
+ })
})
}
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index 89cb537a..2fdc3cfb 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -37,7 +37,7 @@ type Invites interface {
DeleteInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string) (types.StreamPosition, error)
// SelectInviteEventsInRange returns a map of room ID to invite events. If multiple invite/retired invites exist in the given range, return the latest value
// for the room.
- SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, err error)
+ SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error)
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}
diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go
index 0297d5c2..3f2f7d13 100644
--- a/syncapi/streams/stream_accountdata.go
+++ b/syncapi/streams/stream_accountdata.go
@@ -5,22 +5,25 @@ import (
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type AccountDataStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
userAPI userapi.SyncUserAPI
}
-func (p *AccountDataStreamProvider) Setup() {
- p.StreamProvider.Setup()
+func (p *AccountDataStreamProvider) Setup(
+ ctx context.Context, snapshot storage.DatabaseTransaction,
+) {
+ p.DefaultStreamProvider.Setup(ctx, snapshot)
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
- id, err := p.DB.MaxStreamPositionForAccountData(context.Background())
+ id, err := snapshot.MaxStreamPositionForAccountData(ctx)
if err != nil {
panic(err)
}
@@ -29,13 +32,15 @@ func (p *AccountDataStreamProvider) Setup() {
func (p *AccountDataStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
- return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
+ return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *AccountDataStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
@@ -44,7 +49,7 @@ func (p *AccountDataStreamProvider) IncrementalSync(
To: to,
}
- dataTypes, pos, err := p.DB.GetAccountDataInRange(
+ dataTypes, pos, err := snapshot.GetAccountDataInRange(
ctx, req.Device.UserID, r, &req.Filter.AccountData,
)
if err != nil {
diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go
index 5448ee5b..7996c203 100644
--- a/syncapi/streams/stream_devicelist.go
+++ b/syncapi/streams/stream_devicelist.go
@@ -6,17 +6,19 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type DeviceListStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
rsAPI api.SyncRoomserverAPI
keyAPI keyapi.SyncKeyAPI
}
func (p *DeviceListStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
return p.LatestPosition(ctx)
@@ -24,11 +26,12 @@ func (p *DeviceListStreamProvider) CompleteSync(
func (p *DeviceListStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
var err error
- to, _, err = internal.DeviceListCatchup(context.Background(), p.DB, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
+ to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from
diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go
index 925da32f..17b3b843 100644
--- a/syncapi/streams/stream_invite.go
+++ b/syncapi/streams/stream_invite.go
@@ -9,20 +9,23 @@ import (
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type InviteStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
}
-func (p *InviteStreamProvider) Setup() {
- p.StreamProvider.Setup()
+func (p *InviteStreamProvider) Setup(
+ ctx context.Context, snapshot storage.DatabaseTransaction,
+) {
+ p.DefaultStreamProvider.Setup(ctx, snapshot)
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
- id, err := p.DB.MaxStreamPositionForInvites(context.Background())
+ id, err := snapshot.MaxStreamPositionForInvites(ctx)
if err != nil {
panic(err)
}
@@ -31,13 +34,15 @@ func (p *InviteStreamProvider) Setup() {
func (p *InviteStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
- return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
+ return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *InviteStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
@@ -46,7 +51,7 @@ func (p *InviteStreamProvider) IncrementalSync(
To: to,
}
- invites, retiredInvites, err := p.DB.InviteEventsInRange(
+ invites, retiredInvites, maxID, err := snapshot.InviteEventsInRange(
ctx, req.Device.UserID, r,
)
if err != nil {
@@ -86,5 +91,5 @@ func (p *InviteStreamProvider) IncrementalSync(
}
}
- return to
+ return maxID
}
diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go
index 33872734..5a81fd09 100644
--- a/syncapi/streams/stream_notificationdata.go
+++ b/syncapi/streams/stream_notificationdata.go
@@ -3,17 +3,23 @@ package streams
import (
"context"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type NotificationDataStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
}
-func (p *NotificationDataStreamProvider) Setup() {
- p.StreamProvider.Setup()
+func (p *NotificationDataStreamProvider) Setup(
+ ctx context.Context, snapshot storage.DatabaseTransaction,
+) {
+ p.DefaultStreamProvider.Setup(ctx, snapshot)
- id, err := p.DB.MaxStreamPositionForNotificationData(context.Background())
+ p.latestMutex.Lock()
+ defer p.latestMutex.Unlock()
+
+ id, err := snapshot.MaxStreamPositionForNotificationData(ctx)
if err != nil {
panic(err)
}
@@ -22,20 +28,22 @@ func (p *NotificationDataStreamProvider) Setup() {
func (p *NotificationDataStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
- return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
+ return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, _ types.StreamPosition,
) types.StreamPosition {
// Get the unread notifications for rooms in our join response.
// This is to ensure clients always have an unread notification section
// and can display the correct numbers.
- countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
+ countsByRoom, err := snapshot.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms)
if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed")
return from
diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go
index 0ab6de88..89c5ba35 100644
--- a/syncapi/streams/stream_pdu.go
+++ b/syncapi/streams/stream_pdu.go
@@ -5,7 +5,6 @@ import (
"database/sql"
"fmt"
"sort"
- "sync"
"time"
"github.com/matrix-org/dendrite/internal/caching"
@@ -18,7 +17,6 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
- "go.uber.org/atomic"
"github.com/matrix-org/dendrite/syncapi/notifier"
)
@@ -33,44 +31,23 @@ const PDU_STREAM_WORKERS = 256
const PDU_STREAM_QUEUESIZE = PDU_STREAM_WORKERS * 8
type PDUStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
- tasks chan func()
- workers atomic.Int32
// userID+deviceID -> lazy loading cache
lazyLoadCache caching.LazyLoadCache
rsAPI roomserverAPI.SyncRoomserverAPI
notifier *notifier.Notifier
}
-func (p *PDUStreamProvider) worker() {
- defer p.workers.Dec()
- for {
- select {
- case f := <-p.tasks:
- f()
- case <-time.After(time.Second * 10):
- return
- }
- }
-}
-
-func (p *PDUStreamProvider) queue(f func()) {
- if p.workers.Load() < PDU_STREAM_WORKERS {
- p.workers.Inc()
- go p.worker()
- }
- p.tasks <- f
-}
-
-func (p *PDUStreamProvider) Setup() {
- p.StreamProvider.Setup()
- p.tasks = make(chan func(), PDU_STREAM_QUEUESIZE)
+func (p *PDUStreamProvider) Setup(
+ ctx context.Context, snapshot storage.DatabaseTransaction,
+) {
+ p.DefaultStreamProvider.Setup(ctx, snapshot)
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
- id, err := p.DB.MaxStreamPositionForPDUs(context.Background())
+ id, err := snapshot.MaxStreamPositionForPDUs(ctx)
if err != nil {
panic(err)
}
@@ -79,6 +56,7 @@ func (p *PDUStreamProvider) Setup() {
func (p *PDUStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
from := types.StreamPosition(0)
@@ -94,7 +72,7 @@ func (p *PDUStreamProvider) CompleteSync(
}
// Extract room state and recent events for all rooms the user is joined to.
- joinedRoomIDs, err := p.DB.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join)
+ joinedRoomIDs, err := snapshot.RoomIDsWithMembership(ctx, req.Device.UserID, gomatrixserverlib.Join)
if err != nil {
req.Log.WithError(err).Error("p.DB.RoomIDsWithMembership failed")
return from
@@ -103,7 +81,7 @@ func (p *PDUStreamProvider) CompleteSync(
stateFilter := req.Filter.Room.State
eventFilter := req.Filter.Room.Timeline
- if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil {
+ if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil {
req.Log.WithError(err).Error("unable to update event filter with ignored users")
}
@@ -117,33 +95,20 @@ func (p *PDUStreamProvider) CompleteSync(
}
// Build up a /sync response. Add joined rooms.
- var reqMutex sync.Mutex
- var reqWaitGroup sync.WaitGroup
- reqWaitGroup.Add(len(joinedRoomIDs))
- for _, room := range joinedRoomIDs {
- roomID := room
- p.queue(func() {
- defer reqWaitGroup.Done()
-
- jr, jerr := p.getJoinResponseForCompleteSync(
- ctx, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
- )
- if jerr != nil {
- req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
- return
- }
-
- reqMutex.Lock()
- defer reqMutex.Unlock()
- req.Response.Rooms.Join[roomID] = *jr
- req.Rooms[roomID] = gomatrixserverlib.Join
- })
+ for _, roomID := range joinedRoomIDs {
+ jr, jerr := p.getJoinResponseForCompleteSync(
+ ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false,
+ )
+ if jerr != nil {
+ req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed")
+ continue // return from
+ }
+ req.Response.Rooms.Join[roomID] = *jr
+ req.Rooms[roomID] = gomatrixserverlib.Join
}
- reqWaitGroup.Wait()
-
// Add peeked rooms.
- peeks, err := p.DB.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r)
+ peeks, err := snapshot.PeeksInRange(ctx, req.Device.UserID, req.Device.ID, r)
if err != nil {
req.Log.WithError(err).Error("p.DB.PeeksInRange failed")
return from
@@ -152,11 +117,11 @@ func (p *PDUStreamProvider) CompleteSync(
if !peek.Deleted {
var jr *types.JoinResponse
jr, err = p.getJoinResponseForCompleteSync(
- ctx, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
+ ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true,
)
if err != nil {
req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed")
- return from
+ continue // return from
}
req.Response.Rooms.Peek[peek.RoomID] = *jr
}
@@ -167,6 +132,7 @@ func (p *PDUStreamProvider) CompleteSync(
func (p *PDUStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) (newPos types.StreamPosition) {
@@ -184,12 +150,12 @@ func (p *PDUStreamProvider) IncrementalSync(
eventFilter := req.Filter.Room.Timeline
if req.WantFullState {
- if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
+ if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed")
return
}
} else {
- if stateDeltas, syncJoinedRooms, err = p.DB.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
+ if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil {
req.Log.WithError(err).Error("p.DB.GetStateDeltas failed")
return
}
@@ -203,7 +169,7 @@ func (p *PDUStreamProvider) IncrementalSync(
return to
}
- if err = p.addIgnoredUsersToFilter(ctx, req, &eventFilter); err != nil {
+ if err = p.addIgnoredUsersToFilter(ctx, snapshot, req, &eventFilter); err != nil {
req.Log.WithError(err).Error("unable to update event filter with ignored users")
}
@@ -222,7 +188,7 @@ func (p *PDUStreamProvider) IncrementalSync(
}
}
var pos types.StreamPosition
- if pos, err = p.addRoomDeltaToResponse(ctx, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
+ if pos, err = p.addRoomDeltaToResponse(ctx, snapshot, req.Device, newRange, delta, &eventFilter, &stateFilter, req.Response); err != nil {
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
return to
}
@@ -244,6 +210,7 @@ func (p *PDUStreamProvider) IncrementalSync(
// nolint:gocyclo
func (p *PDUStreamProvider) addRoomDeltaToResponse(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
device *userapi.Device,
r types.Range,
delta types.StateDelta,
@@ -260,7 +227,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// This is all "okay" assuming history_visibility == "shared" which it is by default.
r.To = delta.MembershipPos
}
- recentStreamEvents, limited, err := p.DB.RecentEvents(
+ recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, delta.RoomID, r,
eventFilter, true, true,
)
@@ -270,9 +237,9 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err)
}
- recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
+ recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents)
delta.StateEvents = removeDuplicates(delta.StateEvents, recentEvents) // roll back
- prevBatch, err := p.DB.GetBackwardTopologyPos(ctx, recentStreamEvents)
+ prevBatch, err := snapshot.GetBackwardTopologyPos(ctx, recentStreamEvents)
if err != nil {
return r.From, fmt.Errorf("p.DB.GetBackwardTopologyPos: %w", err)
}
@@ -291,7 +258,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
latestPosition := r.To
updateLatestPosition := func(mostRecentEventID string) {
var pos types.StreamPosition
- if _, pos, err = p.DB.PositionInTopology(ctx, mostRecentEventID); err == nil {
+ if _, pos, err = snapshot.PositionInTopology(ctx, mostRecentEventID); err == nil {
switch {
case r.Backwards && pos < latestPosition:
fallthrough
@@ -303,7 +270,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
if stateFilter.LazyLoadMembers {
delta.StateEvents, err = p.lazyLoadMembers(
- ctx, delta.RoomID, true, limited, stateFilter,
+ ctx, snapshot, delta.RoomID, true, limited, stateFilter,
device, recentEvents, delta.StateEvents,
)
if err != nil && err != sql.ErrNoRows {
@@ -320,7 +287,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
}
// Applies the history visibility rules
- events, err := applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
+ events, err := applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, delta.RoomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
@@ -336,7 +303,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
case gomatrixserverlib.Join:
jr := types.NewJoinResponse()
if hasMembershipChange {
- p.addRoomSummary(ctx, jr, delta.RoomID, device.UserID, latestPosition)
+ p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition)
}
jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync)
@@ -376,7 +343,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
// sure we always return the required events in the timeline.
func applyHistoryVisibilityFilter(
ctx context.Context,
- db storage.Database,
+ snapshot storage.DatabaseTransaction,
rsAPI roomserverAPI.SyncRoomserverAPI,
roomID, userID string,
limit int,
@@ -384,7 +351,7 @@ func applyHistoryVisibilityFilter(
) ([]*gomatrixserverlib.HeaderedEvent, error) {
// We need to make sure we always include the latest states events, if they are in the timeline.
// We grep at least limit * 2 events, to ensure we really get the needed events.
- stateEvents, err := db.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
+ stateEvents, err := snapshot.CurrentState(ctx, roomID, &gomatrixserverlib.StateFilter{Limit: limit * 2}, nil)
if err != nil {
// Not a fatal error, we can continue without the stateEvents,
// they are only needed if there are state events in the timeline.
@@ -395,7 +362,7 @@ func applyHistoryVisibilityFilter(
alwaysIncludeIDs[ev.EventID()] = struct{}{}
}
startTime := time.Now()
- events, err := internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
+ events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
if err != nil {
return nil, err
}
@@ -408,10 +375,10 @@ func applyHistoryVisibilityFilter(
return events, nil
}
-func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
+func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) {
// Work out how many members are in the room.
- joinedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
- invitedCount, _ := p.DB.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
+ joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition)
+ invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition)
jr.Summary.JoinedMemberCount = &joinedCount
jr.Summary.InvitedMemberCount = &invitedCount
@@ -439,7 +406,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe
}
}
}
- heroes, err := p.DB.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
+ heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"})
if err != nil {
return
}
@@ -449,6 +416,7 @@ func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, jr *types.JoinRe
func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
roomID string,
r types.Range,
stateFilter *gomatrixserverlib.StateFilter,
@@ -460,7 +428,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
jr = types.NewJoinResponse()
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
- recentStreamEvents, limited, err := p.DB.RecentEvents(
+ recentStreamEvents, limited, err := snapshot.RecentEvents(
ctx, roomID, r, eventFilter, true, true,
)
if err != nil {
@@ -484,7 +452,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
}
}
- stateEvents, err := p.DB.CurrentState(ctx, roomID, stateFilter, excludingEventIDs)
+ stateEvents, err := snapshot.CurrentState(ctx, roomID, stateFilter, excludingEventIDs)
if err != nil {
return
}
@@ -494,7 +462,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
var prevBatch *types.TopologyToken
if len(recentStreamEvents) > 0 {
var backwardTopologyPos, backwardStreamPos types.StreamPosition
- backwardTopologyPos, backwardStreamPos, err = p.DB.PositionInTopology(ctx, recentStreamEvents[0].EventID())
+ backwardTopologyPos, backwardStreamPos, err = snapshot.PositionInTopology(ctx, recentStreamEvents[0].EventID())
if err != nil {
return
}
@@ -505,18 +473,18 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
prevBatch.Decrement()
}
- p.addRoomSummary(ctx, jr, roomID, device.UserID, r.From)
+ p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From)
// We don't include a device here as we don't need to send down
// transaction IDs for complete syncs, but we do it anyway because Sytest demands it for:
// "Can sync a room with a message with a transaction id" - which does a complete sync to check.
- recentEvents := p.DB.StreamEventsToEvents(device, recentStreamEvents)
+ recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
events := recentEvents
// Only apply history visibility checks if the response is for joined rooms
if !isPeek {
- events, err = applyHistoryVisibilityFilter(ctx, p.DB, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
+ events, err = applyHistoryVisibilityFilter(ctx, snapshot, p.rsAPI, roomID, device.UserID, eventFilter.Limit, recentEvents)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
}
@@ -530,7 +498,8 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
if err != nil {
return nil, err
}
- stateEvents, err = p.lazyLoadMembers(ctx, roomID,
+ stateEvents, err = p.lazyLoadMembers(
+ ctx, snapshot, roomID,
false, limited, stateFilter,
device, recentEvents, stateEvents,
)
@@ -549,7 +518,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
}
func (p *PDUStreamProvider) lazyLoadMembers(
- ctx context.Context, roomID string,
+ ctx context.Context, snapshot storage.DatabaseTransaction, roomID string,
incremental, limited bool, stateFilter *gomatrixserverlib.StateFilter,
device *userapi.Device,
timelineEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
@@ -598,7 +567,7 @@ func (p *PDUStreamProvider) lazyLoadMembers(
filter.Limit = stateFilter.Limit
filter.Senders = &wantUsers
filter.Types = &[]string{gomatrixserverlib.MRoomMember}
- memberships, err := p.DB.GetStateEventsForRoom(ctx, roomID, &filter)
+ memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter)
if err != nil {
return stateEvents, err
}
@@ -612,8 +581,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// addIgnoredUsersToFilter adds ignored users to the eventfilter and
// the syncreq itself for further use in streams.
-func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
- ignores, err := p.DB.IgnoresForUser(ctx, req.Device.UserID)
+func (p *PDUStreamProvider) addIgnoredUsersToFilter(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, eventFilter *gomatrixserverlib.RoomEventFilter) error {
+ ignores, err := snapshot.IgnoresForUser(ctx, req.Device.UserID)
if err != nil {
if err == sql.ErrNoRows {
return nil
diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go
index 15db4d30..81cea7d5 100644
--- a/syncapi/streams/stream_presence.go
+++ b/syncapi/streams/stream_presence.go
@@ -23,20 +23,26 @@ import (
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/syncapi/notifier"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type PresenceStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
// cache contains previously sent presence updates to avoid unneeded updates
cache sync.Map
notifier *notifier.Notifier
}
-func (p *PresenceStreamProvider) Setup() {
- p.StreamProvider.Setup()
+func (p *PresenceStreamProvider) Setup(
+ ctx context.Context, snapshot storage.DatabaseTransaction,
+) {
+ p.DefaultStreamProvider.Setup(ctx, snapshot)
- id, err := p.DB.MaxStreamPositionForPresence(context.Background())
+ p.latestMutex.Lock()
+ defer p.latestMutex.Unlock()
+
+ id, err := snapshot.MaxStreamPositionForPresence(ctx)
if err != nil {
panic(err)
}
@@ -45,18 +51,20 @@ func (p *PresenceStreamProvider) Setup() {
func (p *PresenceStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
- return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
+ return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *PresenceStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
// We pull out a larger number than the filter asks for, since we're filtering out events later
- presences, err := p.DB.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000})
+ presences, err := snapshot.PresenceAfter(ctx, from, gomatrixserverlib.EventFilter{Limit: 1000})
if err != nil {
req.Log.WithError(err).Error("p.DB.PresenceAfter failed")
return from
@@ -84,7 +92,7 @@ func (p *PresenceStreamProvider) IncrementalSync(
}
// Bear in mind that this might return nil, but at least populating
// a nil means that there's a map entry so we won't repeat this call.
- presences[roomUsers[i]], err = p.DB.GetPresence(ctx, roomUsers[i])
+ presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i])
if err != nil {
req.Log.WithError(err).Error("unable to query presence for user")
return from
diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go
index f4e84c7d..8818a553 100644
--- a/syncapi/streams/stream_receipt.go
+++ b/syncapi/streams/stream_receipt.go
@@ -4,18 +4,24 @@ import (
"context"
"encoding/json"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type ReceiptStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
}
-func (p *ReceiptStreamProvider) Setup() {
- p.StreamProvider.Setup()
+func (p *ReceiptStreamProvider) Setup(
+ ctx context.Context, snapshot storage.DatabaseTransaction,
+) {
+ p.DefaultStreamProvider.Setup(ctx, snapshot)
- id, err := p.DB.MaxStreamPositionForReceipts(context.Background())
+ p.latestMutex.Lock()
+ defer p.latestMutex.Unlock()
+
+ id, err := snapshot.MaxStreamPositionForReceipts(ctx)
if err != nil {
panic(err)
}
@@ -24,13 +30,15 @@ func (p *ReceiptStreamProvider) Setup() {
func (p *ReceiptStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
- return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
+ return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *ReceiptStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
@@ -41,7 +49,7 @@ func (p *ReceiptStreamProvider) IncrementalSync(
}
}
- lastPos, receipts, err := p.DB.RoomReceiptsAfter(ctx, joinedRooms, from)
+ lastPos, receipts, err := snapshot.RoomReceiptsAfter(ctx, joinedRooms, from)
if err != nil {
req.Log.WithError(err).Error("p.DB.RoomReceiptsAfter failed")
return from
diff --git a/syncapi/streams/stream_sendtodevice.go b/syncapi/streams/stream_sendtodevice.go
index 31c6187c..00b67cc4 100644
--- a/syncapi/streams/stream_sendtodevice.go
+++ b/syncapi/streams/stream_sendtodevice.go
@@ -3,17 +3,23 @@ package streams
import (
"context"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type SendToDeviceStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
}
-func (p *SendToDeviceStreamProvider) Setup() {
- p.StreamProvider.Setup()
+func (p *SendToDeviceStreamProvider) Setup(
+ ctx context.Context, snapshot storage.DatabaseTransaction,
+) {
+ p.DefaultStreamProvider.Setup(ctx, snapshot)
- id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background())
+ p.latestMutex.Lock()
+ defer p.latestMutex.Unlock()
+
+ id, err := snapshot.MaxStreamPositionForSendToDeviceMessages(ctx)
if err != nil {
panic(err)
}
@@ -22,18 +28,20 @@ func (p *SendToDeviceStreamProvider) Setup() {
func (p *SendToDeviceStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
- return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
+ return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *SendToDeviceStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
// See if we have any new tasks to do for the send-to-device messaging.
- lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
+ lastPos, events, err := snapshot.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
if err != nil {
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
return from
diff --git a/syncapi/streams/stream_typing.go b/syncapi/streams/stream_typing.go
index f781065b..a6f7c7a0 100644
--- a/syncapi/streams/stream_typing.go
+++ b/syncapi/streams/stream_typing.go
@@ -5,24 +5,27 @@ import (
"encoding/json"
"github.com/matrix-org/dendrite/internal/caching"
+ "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type TypingStreamProvider struct {
- StreamProvider
+ DefaultStreamProvider
EDUCache *caching.EDUCache
}
func (p *TypingStreamProvider) CompleteSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
) types.StreamPosition {
- return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
+ return p.IncrementalSync(ctx, snapshot, req, 0, p.LatestPosition(ctx))
}
func (p *TypingStreamProvider) IncrementalSync(
ctx context.Context,
+ snapshot storage.DatabaseTransaction,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
diff --git a/syncapi/streams/streamprovider.go b/syncapi/streams/streamprovider.go
new file mode 100644
index 00000000..8b12e2eb
--- /dev/null
+++ b/syncapi/streams/streamprovider.go
@@ -0,0 +1,28 @@
+package streams
+
+import (
+ "context"
+
+ "github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/types"
+)
+
+type StreamProvider interface {
+ Setup(ctx context.Context, snapshot storage.DatabaseTransaction)
+
+ // Advance will update the latest position of the stream based on
+ // an update and will wake callers waiting on StreamNotifyAfter.
+ Advance(latest types.StreamPosition)
+
+ // CompleteSync will update the response to include all updates as needed
+ // for a complete sync. It will always return immediately.
+ CompleteSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest) types.StreamPosition
+
+ // IncrementalSync will update the response to include all updates between
+ // the from and to sync positions. It will always return immediately,
+ // making no changes if the range contains no updates.
+ IncrementalSync(ctx context.Context, snapshot storage.DatabaseTransaction, req *types.SyncRequest, from, to types.StreamPosition) types.StreamPosition
+
+ // LatestPosition returns the latest stream position for this stream.
+ LatestPosition(ctx context.Context) types.StreamPosition
+}
diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go
index dbc053bd..eccbb3a4 100644
--- a/syncapi/streams/streams.go
+++ b/syncapi/streams/streams.go
@@ -13,15 +13,15 @@ import (
)
type Streams struct {
- PDUStreamProvider types.StreamProvider
- TypingStreamProvider types.StreamProvider
- ReceiptStreamProvider types.StreamProvider
- InviteStreamProvider types.StreamProvider
- SendToDeviceStreamProvider types.StreamProvider
- AccountDataStreamProvider types.StreamProvider
- DeviceListStreamProvider types.StreamProvider
- NotificationDataStreamProvider types.StreamProvider
- PresenceStreamProvider types.StreamProvider
+ PDUStreamProvider StreamProvider
+ TypingStreamProvider StreamProvider
+ ReceiptStreamProvider StreamProvider
+ InviteStreamProvider StreamProvider
+ SendToDeviceStreamProvider StreamProvider
+ AccountDataStreamProvider StreamProvider
+ DeviceListStreamProvider StreamProvider
+ NotificationDataStreamProvider StreamProvider
+ PresenceStreamProvider StreamProvider
}
func NewSyncStreamProviders(
@@ -31,51 +31,58 @@ func NewSyncStreamProviders(
) *Streams {
streams := &Streams{
PDUStreamProvider: &PDUStreamProvider{
- StreamProvider: StreamProvider{DB: d},
- lazyLoadCache: lazyLoadCache,
- rsAPI: rsAPI,
- notifier: notifier,
+ DefaultStreamProvider: DefaultStreamProvider{DB: d},
+ lazyLoadCache: lazyLoadCache,
+ rsAPI: rsAPI,
+ notifier: notifier,
},
TypingStreamProvider: &TypingStreamProvider{
- StreamProvider: StreamProvider{DB: d},
- EDUCache: eduCache,
+ DefaultStreamProvider: DefaultStreamProvider{DB: d},
+ EDUCache: eduCache,
},
ReceiptStreamProvider: &ReceiptStreamProvider{
- StreamProvider: StreamProvider{DB: d},
+ DefaultStreamProvider: DefaultStreamProvider{DB: d},
},
InviteStreamProvider: &InviteStreamProvider{
- StreamProvider: StreamProvider{DB: d},
+ DefaultStreamProvider: DefaultStreamProvider{DB: d},
},
SendToDeviceStreamProvider: &SendToDeviceStreamProvider{
- StreamProvider: StreamProvider{DB: d},
+ DefaultStreamProvider: DefaultStreamProvider{DB: d},
},
AccountDataStreamProvider: &AccountDataStreamProvider{
- StreamProvider: StreamProvider{DB: d},
- userAPI: userAPI,
+ DefaultStreamProvider: DefaultStreamProvider{DB: d},
+ userAPI: userAPI,
},
NotificationDataStreamProvider: &NotificationDataStreamProvider{
- StreamProvider: StreamProvider{DB: d},
+ DefaultStreamProvider: DefaultStreamProvider{DB: d},
},
DeviceListStreamProvider: &DeviceListStreamProvider{
- StreamProvider: StreamProvider{DB: d},
- rsAPI: rsAPI,
- keyAPI: keyAPI,
+ DefaultStreamProvider: DefaultStreamProvider{DB: d},
+ rsAPI: rsAPI,
+ keyAPI: keyAPI,
},
PresenceStreamProvider: &PresenceStreamProvider{
- StreamProvider: StreamProvider{DB: d},
- notifier: notifier,
+ DefaultStreamProvider: DefaultStreamProvider{DB: d},
+ notifier: notifier,
},
}
- streams.PDUStreamProvider.Setup()
- streams.TypingStreamProvider.Setup()
- streams.ReceiptStreamProvider.Setup()
- streams.InviteStreamProvider.Setup()
- streams.SendToDeviceStreamProvider.Setup()
- streams.AccountDataStreamProvider.Setup()
- streams.NotificationDataStreamProvider.Setup()
- streams.DeviceListStreamProvider.Setup()
- streams.PresenceStreamProvider.Setup()
+ ctx := context.TODO()
+ snapshot, err := d.NewDatabaseSnapshot(ctx)
+ if err != nil {
+ panic(err)
+ }
+ defer snapshot.Rollback() // nolint:errcheck
+
+ streams.PDUStreamProvider.Setup(ctx, snapshot)
+ streams.TypingStreamProvider.Setup(ctx, snapshot)
+ streams.ReceiptStreamProvider.Setup(ctx, snapshot)
+ streams.InviteStreamProvider.Setup(ctx, snapshot)
+ streams.SendToDeviceStreamProvider.Setup(ctx, snapshot)
+ streams.AccountDataStreamProvider.Setup(ctx, snapshot)
+ streams.NotificationDataStreamProvider.Setup(ctx, snapshot)
+ streams.DeviceListStreamProvider.Setup(ctx, snapshot)
+ streams.PresenceStreamProvider.Setup(ctx, snapshot)
return streams
}
diff --git a/syncapi/streams/template_stream.go b/syncapi/streams/template_stream.go
index 15074cc1..f208d84e 100644
--- a/syncapi/streams/template_stream.go
+++ b/syncapi/streams/template_stream.go
@@ -8,16 +8,18 @@ import (
"github.com/matrix-org/dendrite/syncapi/types"
)
-type StreamProvider struct {
+type DefaultStreamProvider struct {
DB storage.Database
latest types.StreamPosition
latestMutex sync.RWMutex
}
-func (p *StreamProvider) Setup() {
+func (p *DefaultStreamProvider) Setup(
+ ctx context.Context, snapshot storage.DatabaseTransaction,
+) {
}
-func (p *StreamProvider) Advance(
+func (p *DefaultStreamProvider) Advance(
latest types.StreamPosition,
) {
p.latestMutex.Lock()
@@ -28,7 +30,7 @@ func (p *StreamProvider) Advance(
}
}
-func (p *StreamProvider) LatestPosition(
+func (p *DefaultStreamProvider) LatestPosition(
ctx context.Context,
) types.StreamPosition {
p.latestMutex.RLock()
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index b2ea105f..1d0ac1a4 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -305,6 +305,13 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
}
+ snapshot, err := rp.db.NewDatabaseSnapshot(req.Context())
+ if err != nil {
+ logrus.WithError(err).Error("Failed to acquire database snapshot for sync request")
+ return jsonerror.InternalServerError()
+ }
+ defer snapshot.Rollback() // nolint:errcheck
+
if syncReq.Since.IsEmpty() {
// Complete sync
syncReq.Response.NextBatch = types.StreamingToken{
@@ -312,70 +319,70 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
// might advance while processing other streams, resulting in flakey
// tests.
DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
),
PDUPosition: rp.streams.PDUStreamProvider.CompleteSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
),
TypingPosition: rp.streams.TypingStreamProvider.CompleteSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
),
ReceiptPosition: rp.streams.ReceiptStreamProvider.CompleteSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
),
InvitePosition: rp.streams.InviteStreamProvider.CompleteSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
),
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.CompleteSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
),
AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
),
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
),
PresencePosition: rp.streams.PresenceStreamProvider.CompleteSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
),
}
} else {
// Incremental sync
syncReq.Response.NextBatch = types.StreamingToken{
PDUPosition: rp.streams.PDUStreamProvider.IncrementalSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
syncReq.Since.PDUPosition, currentPos.PDUPosition,
),
TypingPosition: rp.streams.TypingStreamProvider.IncrementalSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
syncReq.Since.TypingPosition, currentPos.TypingPosition,
),
ReceiptPosition: rp.streams.ReceiptStreamProvider.IncrementalSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
syncReq.Since.ReceiptPosition, currentPos.ReceiptPosition,
),
InvitePosition: rp.streams.InviteStreamProvider.IncrementalSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
syncReq.Since.InvitePosition, currentPos.InvitePosition,
),
SendToDevicePosition: rp.streams.SendToDeviceStreamProvider.IncrementalSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
syncReq.Since.SendToDevicePosition, currentPos.SendToDevicePosition,
),
AccountDataPosition: rp.streams.AccountDataStreamProvider.IncrementalSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition,
),
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition,
),
DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition,
),
PresencePosition: rp.streams.PresenceStreamProvider.IncrementalSync(
- syncReq.Context, syncReq,
+ syncReq.Context, snapshot, syncReq,
syncReq.Since.PresencePosition, currentPos.PresencePosition,
),
}
@@ -437,9 +444,15 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
util.GetLogger(req.Context()).WithError(err).Error("newSyncRequest failed")
return jsonerror.InternalServerError()
}
- rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), syncReq, fromToken.PDUPosition, toToken.PDUPosition)
+ snapshot, err := rp.db.NewDatabaseSnapshot(req.Context())
+ if err != nil {
+ logrus.WithError(err).Error("Failed to acquire database snapshot for key change")
+ return jsonerror.InternalServerError()
+ }
+ defer snapshot.Rollback() // nolint:errcheck
+ rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition)
_, _, err = internal.DeviceListCatchup(
- req.Context(), rp.db, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
+ req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
)
if err != nil {
diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go
index a9ea234d..378cafe9 100644
--- a/syncapi/types/provider.go
+++ b/syncapi/types/provider.go
@@ -41,23 +41,3 @@ func (r *SyncRequest) IsRoomPresent(roomID string) bool {
return false
}
}
-
-type StreamProvider interface {
- Setup()
-
- // Advance will update the latest position of the stream based on
- // an update and will wake callers waiting on StreamNotifyAfter.
- Advance(latest StreamPosition)
-
- // CompleteSync will update the response to include all updates as needed
- // for a complete sync. It will always return immediately.
- CompleteSync(ctx context.Context, req *SyncRequest) StreamPosition
-
- // IncrementalSync will update the response to include all updates between
- // the from and to sync positions. It will always return immediately,
- // making no changes if the range contains no updates.
- IncrementalSync(ctx context.Context, req *SyncRequest, from, to StreamPosition) StreamPosition
-
- // LatestPosition returns the latest stream position for this stream.
- LatestPosition(ctx context.Context) StreamPosition
-}