aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage/postgres/syncserver.go
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi/storage/postgres/syncserver.go')
-rw-r--r--syncapi/storage/postgres/syncserver.go400
1 files changed, 298 insertions, 102 deletions
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go
index 3a62d136..621aec95 100644
--- a/syncapi/storage/postgres/syncserver.go
+++ b/syncapi/storage/postgres/syncserver.go
@@ -20,7 +20,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
- "strconv"
"time"
"github.com/sirupsen/logrus"
@@ -43,29 +42,24 @@ type stateDelta struct {
membership string
// The PDU stream position of the latest membership event for this user, if applicable.
// Can be 0 if there is no membership event in this delta.
- membershipPos int64
+ membershipPos types.StreamPosition
}
-// Same as gomatrixserverlib.Event but also has the PDU stream position for this event.
-type streamEvent struct {
- gomatrixserverlib.Event
- streamPosition int64
- transactionID *api.TransactionID
-}
-
-// SyncServerDatabase represents a sync server datasource which manages
+// SyncServerDatasource represents a sync server datasource which manages
// both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
db *sql.DB
common.PartitionOffsetStatements
- accountData accountDataStatements
- events outputRoomEventsStatements
- roomstate currentRoomStateStatements
- invites inviteEventsStatements
- typingCache *cache.TypingCache
+ accountData accountDataStatements
+ events outputRoomEventsStatements
+ roomstate currentRoomStateStatements
+ invites inviteEventsStatements
+ typingCache *cache.TypingCache
+ topology outputRoomEventsTopologyStatements
+ backwardExtremities backwardExtremitiesStatements
}
-// NewSyncServerDatabase creates a new sync server database
+// NewSyncServerDatasource creates a new sync server database
func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, error) {
var d SyncServerDatasource
var err error
@@ -87,6 +81,12 @@ func NewSyncServerDatasource(dbDataSourceName string) (*SyncServerDatasource, er
if err := d.invites.prepare(d.db); err != nil {
return nil, err
}
+ if err := d.topology.prepare(d.db); err != nil {
+ return nil, err
+ }
+ if err := d.backwardExtremities.prepare(d.db); err != nil {
+ return nil, err
+ }
d.typingCache = cache.NewTypingCache()
return &d, nil
}
@@ -109,7 +109,46 @@ func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([
// We don't include a device here as we only include transaction IDs in
// incremental syncs.
- return streamEventsToEvents(nil, streamEvents), nil
+ return d.StreamEventsToEvents(nil, streamEvents), nil
+}
+
+func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, ev *gomatrixserverlib.Event) error {
+ // If the event is already known as a backward extremity, don't consider
+ // it as such anymore now that we have it.
+ isBackwardExtremity, err := d.backwardExtremities.isBackwardExtremity(ctx, ev.RoomID(), ev.EventID())
+ if err != nil {
+ return err
+ }
+ if isBackwardExtremity {
+ if err = d.backwardExtremities.deleteBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil {
+ return err
+ }
+ }
+
+ // Check if we have all of the event's previous events. If an event is
+ // missing, add it to the room's backward extremities.
+ prevEvents, err := d.events.selectEvents(ctx, nil, ev.PrevEventIDs())
+ if err != nil {
+ return err
+ }
+ var found bool
+ for _, eID := range ev.PrevEventIDs() {
+ found = false
+ for _, prevEv := range prevEvents {
+ if eID == prevEv.EventID() {
+ found = true
+ }
+ }
+
+ // If the event is missing, consider it a backward extremity.
+ if !found {
+ if err = d.backwardExtremities.insertsBackwardExtremity(ctx, ev.RoomID(), ev.EventID()); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
}
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
@@ -120,16 +159,26 @@ func (d *SyncServerDatasource) WriteEvent(
ev *gomatrixserverlib.Event,
addStateEvents []gomatrixserverlib.Event,
addStateEventIDs, removeStateEventIDs []string,
- transactionID *api.TransactionID,
-) (pduPosition int64, returnErr error) {
+ transactionID *api.TransactionID, excludeFromSync bool,
+) (pduPosition types.StreamPosition, returnErr error) {
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
- pos, err := d.events.insertEvent(ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID)
+ pos, err := d.events.insertEvent(
+ ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
+ )
if err != nil {
return err
}
pduPosition = pos
+ if err = d.topology.insertEventInTopology(ctx, ev); err != nil {
+ return err
+ }
+
+ if err = d.handleBackwardExtremities(ctx, ev); err != nil {
+ return err
+ }
+
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
// Nothing to do, the event may have just been a message event.
return nil
@@ -137,14 +186,15 @@ func (d *SyncServerDatasource) WriteEvent(
return d.updateRoomState(ctx, txn, removeStateEventIDs, addStateEvents, pduPosition)
})
- return
+
+ return pduPosition, returnErr
}
func (d *SyncServerDatasource) updateRoomState(
ctx context.Context, txn *sql.Tx,
removedEventIDs []string,
addedEvents []gomatrixserverlib.Event,
- pduPosition int64,
+ pduPosition types.StreamPosition,
) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs {
@@ -196,14 +246,141 @@ func (d *SyncServerDatasource) GetStateEventsForRoom(
return
}
+// GetEventsInRange retrieves all of the events on a given ordering using the
+// given extremities and limit.
+func (d *SyncServerDatasource) GetEventsInRange(
+ ctx context.Context,
+ from, to *types.PaginationToken,
+ roomID string, limit int,
+ backwardOrdering bool,
+) (events []types.StreamEvent, err error) {
+ // If the pagination token's type is types.PaginationTokenTypeTopology, the
+ // events must be retrieved from the rooms' topology table rather than the
+ // table contaning the syncapi server's whole stream of events.
+ if from.Type == types.PaginationTokenTypeTopology {
+ // Determine the backward and forward limit, i.e. the upper and lower
+ // limits to the selection in the room's topology, from the direction.
+ var backwardLimit, forwardLimit types.StreamPosition
+ if backwardOrdering {
+ // Backward ordering is antichronological (latest event to oldest
+ // one).
+ backwardLimit = to.PDUPosition
+ forwardLimit = from.PDUPosition
+ } else {
+ // Forward ordering is chronological (oldest event to latest one).
+ backwardLimit = from.PDUPosition
+ forwardLimit = to.PDUPosition
+ }
+
+ // Select the event IDs from the defined range.
+ var eIDs []string
+ eIDs, err = d.topology.selectEventIDsInRange(
+ ctx, roomID, backwardLimit, forwardLimit, limit, !backwardOrdering,
+ )
+ if err != nil {
+ return
+ }
+
+ // Retrieve the events' contents using their IDs.
+ events, err = d.events.selectEvents(ctx, nil, eIDs)
+ return
+ }
+
+ // If the pagination token's type is types.PaginationTokenTypeStream, the
+ // events must be retrieved from the table contaning the syncapi server's
+ // whole stream of events.
+
+ if backwardOrdering {
+ // When using backward ordering, we want the most recent events first.
+ if events, err = d.events.selectRecentEvents(
+ ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false,
+ ); err != nil {
+ return
+ }
+ } else {
+ // When using forward ordering, we want the least recent events first.
+ if events, err = d.events.selectEarlyEvents(
+ ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit,
+ ); err != nil {
+ return
+ }
+ }
+
+ return
+}
+
// SyncPosition returns the latest positions for syncing.
-func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.SyncPosition, error) {
+func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) {
return d.syncPositionTx(ctx, nil)
}
+// BackwardExtremitiesForRoom returns the event IDs of all of the backward
+// extremities we know of for a given room.
+func (d *SyncServerDatasource) BackwardExtremitiesForRoom(
+ ctx context.Context, roomID string,
+) (backwardExtremities []string, err error) {
+ return d.backwardExtremities.selectBackwardExtremitiesForRoom(ctx, roomID)
+}
+
+// MaxTopologicalPosition returns the highest topological position for a given
+// room.
+func (d *SyncServerDatasource) MaxTopologicalPosition(
+ ctx context.Context, roomID string,
+) (types.StreamPosition, error) {
+ return d.topology.selectMaxPositionInTopology(ctx, roomID)
+}
+
+// EventsAtTopologicalPosition returns all of the events matching a given
+// position in the topology of a given room.
+func (d *SyncServerDatasource) EventsAtTopologicalPosition(
+ ctx context.Context, roomID string, pos types.StreamPosition,
+) ([]types.StreamEvent, error) {
+ eIDs, err := d.topology.selectEventIDsFromPosition(ctx, roomID, pos)
+ if err != nil {
+ return nil, err
+ }
+
+ return d.events.selectEvents(ctx, nil, eIDs)
+}
+
+func (d *SyncServerDatasource) EventPositionInTopology(
+ ctx context.Context, eventID string,
+) (types.StreamPosition, error) {
+ return d.topology.selectPositionInTopology(ctx, eventID)
+}
+
+// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
+func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) {
+ return d.syncStreamPositionTx(ctx, nil)
+}
+
+func (d *SyncServerDatasource) syncStreamPositionTx(
+ ctx context.Context, txn *sql.Tx,
+) (types.StreamPosition, error) {
+ maxID, err := d.events.selectMaxEventID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ if maxAccountDataID > maxID {
+ maxID = maxAccountDataID
+ }
+ maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ if maxInviteID > maxID {
+ maxID = maxInviteID
+ }
+ return types.StreamPosition(maxID), nil
+}
+
func (d *SyncServerDatasource) syncPositionTx(
ctx context.Context, txn *sql.Tx,
-) (sp types.SyncPosition, err error) {
+) (sp types.PaginationToken, err error) {
maxEventID, err := d.events.selectMaxEventID(ctx, txn)
if err != nil {
@@ -223,10 +400,8 @@ func (d *SyncServerDatasource) syncPositionTx(
if maxInviteID > maxEventID {
maxEventID = maxInviteID
}
- sp.PDUPosition = maxEventID
-
- sp.TypingPosition = d.typingCache.GetLatestSyncPosition()
-
+ sp.PDUPosition = types.StreamPosition(maxEventID)
+ sp.EDUTypingPosition = types.StreamPosition(d.typingCache.GetLatestSyncPosition())
return
}
@@ -235,7 +410,7 @@ func (d *SyncServerDatasource) syncPositionTx(
func (d *SyncServerDatasource) addPDUDeltaToResponse(
ctx context.Context,
device authtypes.Device,
- fromPos, toPos int64,
+ fromPos, toPos types.StreamPosition,
numRecentEventsPerRoom int,
wantFullState bool,
res *types.Response,
@@ -287,7 +462,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
// addTypingDeltaToResponse adds all typing notifications to a sync response
// since the specified position.
func (d *SyncServerDatasource) addTypingDeltaToResponse(
- since int64,
+ since types.PaginationToken,
joinedRoomIDs []string,
res *types.Response,
) error {
@@ -296,7 +471,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
var err error
for _, roomID := range joinedRoomIDs {
if typingUsers, updated := d.typingCache.GetTypingUsersIfUpdatedAfter(
- roomID, since,
+ roomID, int64(since.EDUTypingPosition),
); updated {
ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping,
@@ -321,14 +496,14 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
// the positions of that type are not equal in fromPos and toPos.
func (d *SyncServerDatasource) addEDUDeltaToResponse(
- fromPos, toPos types.SyncPosition,
+ fromPos, toPos types.PaginationToken,
joinedRoomIDs []string,
res *types.Response,
) (err error) {
- if fromPos.TypingPosition != toPos.TypingPosition {
+ if fromPos.EDUTypingPosition != toPos.EDUTypingPosition {
err = d.addTypingDeltaToResponse(
- fromPos.TypingPosition, joinedRoomIDs, res,
+ fromPos, joinedRoomIDs, res,
)
}
@@ -343,7 +518,7 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse(
func (d *SyncServerDatasource) IncrementalSync(
ctx context.Context,
device authtypes.Device,
- fromPos, toPos types.SyncPosition,
+ fromPos, toPos types.PaginationToken,
numRecentEventsPerRoom int,
wantFullState bool,
) (*types.Response, error) {
@@ -383,7 +558,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
numRecentEventsPerRoom int,
) (
res *types.Response,
- toPos types.SyncPosition,
+ toPos types.PaginationToken,
joinedRoomIDs []string,
err error,
) {
@@ -423,27 +598,37 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
}
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
- var recentStreamEvents []streamEvent
+ var recentStreamEvents []types.StreamEvent
recentStreamEvents, err = d.events.selectRecentEvents(
- ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom,
+ ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition,
+ numRecentEventsPerRoom, true, true,
+ //ctx, txn, roomID, 0, toPos.PDUPosition, numRecentEventsPerRoom,
)
if err != nil {
return
}
+ // Retrieve the backward topology position, i.e. the position of the
+ // oldest event in the room's topology.
+ var backwardTopologyPos types.StreamPosition
+ backwardTopologyPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID())
+ if err != nil {
+ return nil, types.PaginationToken{}, []string{}, err
+ }
+ if backwardTopologyPos-1 <= 0 {
+ backwardTopologyPos = types.StreamPosition(1)
+ } else {
+ backwardTopologyPos = backwardTopologyPos - 1
+ }
+
// We don't include a device here as we don't need to send down
// transaction IDs for complete syncs
- recentEvents := streamEventsToEvents(nil, recentStreamEvents)
-
+ recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse()
- if prevPDUPos := recentStreamEvents[0].streamPosition - 1; prevPDUPos > 0 {
- // Use the short form of batch token for prev_batch
- jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
- } else {
- // Use the short form of batch token for prev_batch
- jr.Timeline.PrevBatch = "1"
- }
+ jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true
jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
@@ -471,7 +656,7 @@ func (d *SyncServerDatasource) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse(
- types.SyncPosition{}, toPos, joinedRoomIDs, res,
+ types.PaginationToken{}, toPos, joinedRoomIDs, res,
)
if err != nil {
return nil, err
@@ -496,7 +681,7 @@ var txReadOnlySnapshot = sql.TxOptions{
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
func (d *SyncServerDatasource) GetAccountDataInRange(
- ctx context.Context, userID string, oldPos, newPos int64,
+ ctx context.Context, userID string, oldPos, newPos types.StreamPosition,
accountDataFilterPart *gomatrix.FilterPart,
) (map[string][]string, error) {
return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart)
@@ -510,7 +695,7 @@ func (d *SyncServerDatasource) GetAccountDataInRange(
// Returns an error if there was an issue with the upsert
func (d *SyncServerDatasource) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string,
-) (int64, error) {
+) (types.StreamPosition, error) {
return d.accountData.insertAccountData(ctx, userID, roomID, dataType)
}
@@ -519,7 +704,7 @@ func (d *SyncServerDatasource) UpsertAccountData(
// Returns an error if there was a problem communicating with the database.
func (d *SyncServerDatasource) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.Event,
-) (int64, error) {
+) (types.StreamPosition, error) {
return d.invites.insertInviteEvent(ctx, inviteEvent)
}
@@ -542,26 +727,26 @@ func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallback
// Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) AddTypingUser(
userID, roomID string, expireTime *time.Time,
-) int64 {
- return d.typingCache.AddTypingUser(userID, roomID, expireTime)
+) types.StreamPosition {
+ return types.StreamPosition(d.typingCache.AddTypingUser(userID, roomID, expireTime))
}
// RemoveTypingUser removes a typing user from the typing cache.
// Returns the newly calculated sync position for typing notifications.
func (d *SyncServerDatasource) RemoveTypingUser(
userID, roomID string,
-) int64 {
- return d.typingCache.RemoveUser(userID, roomID)
+) types.StreamPosition {
+ return types.StreamPosition(d.typingCache.RemoveUser(userID, roomID))
}
func (d *SyncServerDatasource) addInvitesToResponse(
ctx context.Context, txn *sql.Tx,
userID string,
- fromPos, toPos int64,
+ fromPos, toPos types.StreamPosition,
res *types.Response,
) error {
invites, err := d.invites.selectInviteEventsInRange(
- ctx, txn, userID, int64(fromPos), int64(toPos),
+ ctx, txn, userID, fromPos, toPos,
)
if err != nil {
return err
@@ -577,12 +762,32 @@ func (d *SyncServerDatasource) addInvitesToResponse(
return nil
}
+// Retrieve the backward topology position, i.e. the position of the
+// oldest event in the room's topology.
+func (d *SyncServerDatasource) getBackwardTopologyPos(
+ ctx context.Context,
+ events []types.StreamEvent,
+) (pos types.StreamPosition, err error) {
+ if len(events) > 0 {
+ pos, err = d.topology.selectPositionInTopology(ctx, events[0].EventID())
+ if err != nil {
+ return
+ }
+ }
+ if pos-1 <= 0 {
+ pos = types.StreamPosition(1)
+ } else {
+ pos = pos - 1
+ }
+ return
+}
+
// addRoomDeltaToResponse adds a room state delta to a sync response
func (d *SyncServerDatasource) addRoomDeltaToResponse(
ctx context.Context,
device *authtypes.Device,
txn *sql.Tx,
- fromPos, toPos int64,
+ fromPos, toPos types.StreamPosition,
delta stateDelta,
numRecentEventsPerRoom int,
res *types.Response,
@@ -598,38 +803,28 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
endPos = delta.membershipPos
}
recentStreamEvents, err := d.events.selectRecentEvents(
- ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom,
+ ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos),
+ numRecentEventsPerRoom, true, true,
)
if err != nil {
return err
}
- recentEvents := streamEventsToEvents(device, recentStreamEvents)
+ recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
- var prevPDUPos int64
-
- if len(recentEvents) == 0 {
- if len(delta.stateEvents) == 0 {
- // Don't bother appending empty room entries
- return nil
- }
-
- // If full_state=true and since is already up to date, then we'll have
- // state events but no recent events.
- prevPDUPos = toPos - 1
- } else {
- prevPDUPos = recentStreamEvents[0].streamPosition - 1
- }
-
- if prevPDUPos <= 0 {
- prevPDUPos = 1
+ var backwardTopologyPos types.StreamPosition
+ backwardTopologyPos, err = d.getBackwardTopologyPos(ctx, recentStreamEvents)
+ if err != nil {
+ return err
}
switch delta.membership {
case gomatrixserverlib.Join:
jr := types.NewJoinResponse()
- // Use the short form of batch token for prev_batch
- jr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
+
+ jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@@ -640,8 +835,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
// TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room.
lr := types.NewLeaveResponse()
- // Use the short form of batch token for prev_batch
- lr.Timeline.PrevBatch = strconv.FormatInt(prevPDUPos, 10)
+ lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
+ types.PaginationTokenTypeTopology, backwardTopologyPos, 0,
+ ).String()
lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@@ -656,9 +852,9 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
func (d *SyncServerDatasource) fetchStateEvents(
ctx context.Context, txn *sql.Tx,
roomIDToEventIDSet map[string]map[string]bool,
- eventIDToEvent map[string]streamEvent,
-) (map[string][]streamEvent, error) {
- stateBetween := make(map[string][]streamEvent)
+ eventIDToEvent map[string]types.StreamEvent,
+) (map[string][]types.StreamEvent, error) {
+ stateBetween := make(map[string][]types.StreamEvent)
missingEvents := make(map[string][]string)
for roomID, ids := range roomIDToEventIDSet {
events := stateBetween[roomID]
@@ -700,7 +896,7 @@ func (d *SyncServerDatasource) fetchStateEvents(
func (d *SyncServerDatasource) fetchMissingStateEvents(
ctx context.Context, txn *sql.Tx, eventIDs []string,
-) ([]streamEvent, error) {
+) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the
// event.
events, err := d.events.selectEvents(ctx, txn, eventIDs)
@@ -743,7 +939,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents(
// A list of joined room IDs is also returned in case the caller needs it.
func (d *SyncServerDatasource) getStateDeltas(
ctx context.Context, device *authtypes.Device, txn *sql.Tx,
- fromPos, toPos int64, userID string,
+ fromPos, toPos types.StreamPosition, userID string,
stateFilterPart *gomatrix.FilterPart,
) ([]stateDelta, []string, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
@@ -776,7 +972,7 @@ func (d *SyncServerDatasource) getStateDeltas(
if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" {
if membership == gomatrixserverlib.Join {
// send full room state down instead of a delta
- var s []streamEvent
+ var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilterPart)
if err != nil {
return nil, nil, err
@@ -787,8 +983,8 @@ func (d *SyncServerDatasource) getStateDeltas(
deltas = append(deltas, stateDelta{
membership: membership,
- membershipPos: ev.streamPosition,
- stateEvents: streamEventsToEvents(device, stateStreamEvents),
+ membershipPos: ev.StreamPosition,
+ stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
roomID: roomID,
})
break
@@ -804,7 +1000,7 @@ func (d *SyncServerDatasource) getStateDeltas(
for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, stateDelta{
membership: gomatrixserverlib.Join,
- stateEvents: streamEventsToEvents(device, state[joinedRoomID]),
+ stateEvents: d.StreamEventsToEvents(device, state[joinedRoomID]),
roomID: joinedRoomID,
})
}
@@ -818,7 +1014,7 @@ func (d *SyncServerDatasource) getStateDeltas(
// updates for other rooms.
func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
ctx context.Context, device *authtypes.Device, txn *sql.Tx,
- fromPos, toPos int64, userID string,
+ fromPos, toPos types.StreamPosition, userID string,
stateFilterPart *gomatrix.FilterPart,
) ([]stateDelta, []string, error) {
joinedRoomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@@ -837,7 +1033,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
}
deltas = append(deltas, stateDelta{
membership: gomatrixserverlib.Join,
- stateEvents: streamEventsToEvents(device, s),
+ stateEvents: d.StreamEventsToEvents(device, s),
roomID: joinedRoomID,
})
}
@@ -858,8 +1054,8 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
if membership != gomatrixserverlib.Join { // We've already added full state for all joined rooms above.
deltas = append(deltas, stateDelta{
membership: membership,
- membershipPos: ev.streamPosition,
- stateEvents: streamEventsToEvents(device, stateStreamEvents),
+ membershipPos: ev.StreamPosition,
+ stateEvents: d.StreamEventsToEvents(device, stateStreamEvents),
roomID: roomID,
})
}
@@ -875,29 +1071,29 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync(
func (d *SyncServerDatasource) currentStateStreamEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrix.FilterPart,
-) ([]streamEvent, error) {
+) ([]types.StreamEvent, error) {
allState, err := d.roomstate.selectCurrentState(ctx, txn, roomID, stateFilterPart)
if err != nil {
return nil, err
}
- s := make([]streamEvent, len(allState))
+ s := make([]types.StreamEvent, len(allState))
for i := 0; i < len(s); i++ {
- s[i] = streamEvent{Event: allState[i], streamPosition: 0}
+ s[i] = types.StreamEvent{Event: allState[i], StreamPosition: 0}
}
return s, nil
}
-// streamEventsToEvents converts streamEvent to Event. If device is non-nil and
+// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event.
-func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event {
+func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.Event {
out := make([]gomatrixserverlib.Event, len(in))
for i := 0; i < len(in); i++ {
out[i] = in[i].Event
- if device != nil && in[i].transactionID != nil {
- if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID {
+ if device != nil && in[i].TransactionID != nil {
+ if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID {
err := out[i].SetUnsignedField(
- "transaction_id", in[i].transactionID.TransactionID,
+ "transaction_id", in[i].TransactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{