aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-07-13 14:18:37 +0200
committerGitHub <noreply@github.com>2023-07-13 14:18:37 +0200
commitf12982472c71b8daf3de682c2807989ee695d2cf (patch)
tree16dce5247f0b7cc2c9416b68edf3bfd212079d87 /syncapi
parent0df982a2e50021183fa478d99b2e463d512ff230 (diff)
Tweaks around `/messages` (#3149)
Try to mitigate some issues with `/messages`
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/routing/messages.go105
-rw-r--r--syncapi/routing/routing.go5
-rw-r--r--syncapi/storage/interface.go7
-rw-r--r--syncapi/storage/postgres/output_room_events_topology_table.go24
-rw-r--r--syncapi/storage/shared/storage_sync.go8
-rw-r--r--syncapi/storage/sqlite3/output_room_events_topology_table.go23
-rw-r--r--syncapi/storage/storage_test.go38
-rw-r--r--syncapi/storage/tables/interface.go6
-rw-r--r--syncapi/storage/tables/topology_test.go42
-rw-r--r--syncapi/syncapi.go3
-rw-r--r--syncapi/syncapi_test.go1
11 files changed, 178 insertions, 84 deletions
diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go
index c3871618..23a09544 100644
--- a/syncapi/routing/messages.go
+++ b/syncapi/routing/messages.go
@@ -53,6 +53,7 @@ type messagesReq struct {
wasToProvided bool
backwardOrdering bool
filter *synctypes.RoomEventFilter
+ didBackfill bool
}
type messagesResp struct {
@@ -251,18 +252,19 @@ func OnIncomingMessagesRequest(
}
// If start and end are equal, we either reached the beginning or something else
- // is wrong. To avoid endless loops from clients, set end to 0 an empty string
- if start == end {
+ // is wrong. If we have nothing to return set end to 0.
+ if start == end || len(clientEvents) == 0 {
end = types.TopologyToken{}
}
util.GetLogger(req.Context()).WithFields(logrus.Fields{
- "from": from.String(),
- "to": to.String(),
- "limit": filter.Limit,
- "backwards": backwardOrdering,
- "return_start": start.String(),
- "return_end": end.String(),
+ "request_from": from.String(),
+ "request_to": to.String(),
+ "limit": filter.Limit,
+ "backwards": backwardOrdering,
+ "response_start": start.String(),
+ "response_end": end.String(),
+ "backfilled": mReq.didBackfill,
}).Info("Responding")
res := messagesResp{
@@ -284,11 +286,6 @@ func OnIncomingMessagesRequest(
})...)
}
- // If we didn't return any events, set the end to an empty string, so it will be omitted
- // in the response JSON.
- if len(res.Chunk) == 0 {
- res.End = ""
- }
if fromStream != nil {
res.StartStream = fromStream.String()
}
@@ -328,11 +325,12 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
) {
emptyToken := types.TopologyToken{}
// Retrieve the events from the local database.
- streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
+ streamEvents, _, end, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err)
- return []synctypes.ClientEvent{}, emptyToken, emptyToken, err
+ return []synctypes.ClientEvent{}, *r.from, emptyToken, err
}
+ end.Decrement()
var events []*rstypes.HeaderedEvent
util.GetLogger(r.ctx).WithFields(logrus.Fields{
@@ -346,32 +344,54 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
// on the ordering), or we've reached a backward extremity.
if len(streamEvents) == 0 {
if events, err = r.handleEmptyEventsSlice(); err != nil {
- return []synctypes.ClientEvent{}, emptyToken, emptyToken, err
+ return []synctypes.ClientEvent{}, *r.from, emptyToken, err
}
} else {
if events, err = r.handleNonEmptyEventsSlice(streamEvents); err != nil {
- return []synctypes.ClientEvent{}, emptyToken, emptyToken, err
+ return []synctypes.ClientEvent{}, *r.from, emptyToken, err
}
}
// If we didn't get any event, we don't need to proceed any further.
if len(events) == 0 {
- return []synctypes.ClientEvent{}, *r.from, *r.to, nil
+ return []synctypes.ClientEvent{}, *r.from, emptyToken, nil
}
- // Get the position of the first and the last event in the room's topology.
- // This position is currently determined by the event's depth, so we could
- // also use it instead of retrieving from the database. However, if we ever
- // change the way topological positions are defined (as depth isn't the most
- // reliable way to define it), it would be easier and less troublesome to
- // only have to change it in one place, i.e. the database.
- start, end, err = r.getStartEnd(events)
+ // Apply room history visibility filter
+ startTime := time.Now()
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages")
if err != nil {
- return []synctypes.ClientEvent{}, *r.from, *r.to, err
+ return []synctypes.ClientEvent{}, *r.from, *r.to, nil
+ }
+ logrus.WithFields(logrus.Fields{
+ "duration": time.Since(startTime),
+ "room_id": r.roomID,
+ "events_before": len(events),
+ "events_after": len(filteredEvents),
+ }).Debug("applied history visibility (messages)")
+
+ // No events left after applying history visibility
+ if len(filteredEvents) == 0 {
+ return []synctypes.ClientEvent{}, *r.from, emptyToken, nil
+ }
+
+ // If we backfilled in the process of getting events, we need
+ // to re-fetch the start/end positions
+ if r.didBackfill {
+ _, end, err = r.getStartEnd(filteredEvents)
+ if err != nil {
+ return []synctypes.ClientEvent{}, *r.from, *r.to, err
+ }
}
// Sort the events to ensure we send them in the right order.
if r.backwardOrdering {
+ if events[len(events)-1].Type() == spec.MRoomCreate {
+ // NOTSPEC: We've hit the beginning of the room so there's really nowhere
+ // else to go. This seems to fix Element iOS from looping on /messages endlessly.
+ end = types.TopologyToken{}
+ }
+
// This reverses the array from old->new to new->old
reversed := func(in []*rstypes.HeaderedEvent) []*rstypes.HeaderedEvent {
out := make([]*rstypes.HeaderedEvent, len(in))
@@ -380,24 +400,14 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
}
return out
}
- events = reversed(events)
- }
- if len(events) == 0 {
- return []synctypes.ClientEvent{}, *r.from, *r.to, nil
+ filteredEvents = reversed(filteredEvents)
}
- // Apply room history visibility filter
- startTime := time.Now()
- filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages")
- logrus.WithFields(logrus.Fields{
- "duration": time.Since(startTime),
- "room_id": r.roomID,
- "events_before": len(events),
- "events_after": len(filteredEvents),
- }).Debug("applied history visibility (messages)")
+ start = *r.from
+
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
- }), start, end, err
+ }), start, end, nil
}
func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) {
@@ -450,6 +460,7 @@ func (r *messagesReq) handleEmptyEventsSlice() (
if err != nil {
return
}
+ r.didBackfill = true
} else {
// If not, it means the slice was empty because we reached the room's
// creation, so return an empty slice.
@@ -499,7 +510,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
if err != nil {
return
}
-
+ r.didBackfill = true
// Append the PDUs to the list to send back to the client.
events = append(events, pdus...)
}
@@ -561,15 +572,17 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
if res.HistoryVisibility == "" {
res.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared
}
- for i := range res.Events {
+ events := res.Events
+ for i := range events {
+ events[i].Visibility = res.HistoryVisibility
_, err = r.db.WriteEvent(
context.Background(),
- res.Events[i],
+ events[i],
[]*rstypes.HeaderedEvent{},
[]string{},
[]string{},
nil, true,
- res.HistoryVisibility,
+ events[i].Visibility,
)
if err != nil {
return nil, err
@@ -577,14 +590,10 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
}
// we may have got more than the requested limit so resize now
- events := res.Events
if len(events) > limit {
// last `limit` events
events = events[len(events)-limit:]
}
- for _, ev := range events {
- ev.Visibility = res.HistoryVisibility
- }
return events, nil
}
diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go
index 8542c0b7..a837e169 100644
--- a/syncapi/routing/routing.go
+++ b/syncapi/routing/routing.go
@@ -43,6 +43,7 @@ func Setup(
cfg *config.SyncAPI,
lazyLoadCache caching.LazyLoadCache,
fts fulltext.Indexer,
+ rateLimits *httputil.RateLimits,
) {
v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter()
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
@@ -53,6 +54,10 @@ func Setup(
}, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ // not specced, but ensure we're rate limiting requests to this endpoint
+ if r := rateLimits.Limit(req, device); r != nil {
+ return *r
+ }
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index 243b2592..dca5d1a1 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -81,8 +81,11 @@ type DatabaseTransaction interface {
// If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error
GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *synctypes.EventFilter) (map[string][]string, types.StreamPosition, error)
- // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
- GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *synctypes.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
+ // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
+ // If backwardsOrdering is true, the most recent event must be first, else last.
+ // Returns the filtered StreamEvents on success. Returns **unfiltered** StreamEvents and ErrNoEventsForFilter if
+ // the provided filter removed all events, this can be used to still calculate the start/end position. (e.g for `/messages`)
+ GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *synctypes.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, start, end types.TopologyToken, err error)
// EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go
index 7140a92f..b281f330 100644
--- a/syncapi/storage/postgres/output_room_events_topology_table.go
+++ b/syncapi/storage/postgres/output_room_events_topology_table.go
@@ -48,14 +48,14 @@ const insertEventInTopologySQL = "" +
" RETURNING topological_position"
const selectEventIDsInRangeASCSQL = "" +
- "SELECT event_id FROM syncapi_output_room_events_topology" +
+ "SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND (" +
"(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position >= $5)" +
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
const selectEventIDsInRangeDESCSQL = "" +
- "SELECT event_id FROM syncapi_output_room_events_topology" +
+ "SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND (" +
"(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" +
@@ -113,12 +113,13 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
}
// SelectEventIDsInRange selects the IDs of events which positions are within a
-// given range in a given room's topological order.
+// given range in a given room's topological order. Returns the start/end topological tokens for
+// the returned eventIDs.
// Returns an empty slice if no events match the given range.
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition,
limit int, chronologicalOrder bool,
-) (eventIDs []string, err error) {
+) (eventIDs []string, start, end types.TopologyToken, err error) {
// Decide on the selection's order according to whether chronological order
// is requested or not.
var stmt *sql.Stmt
@@ -132,7 +133,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit)
if err == sql.ErrNoRows {
// If no event matched the request, return an empty slice.
- return []string{}, nil
+ return []string{}, start, end, nil
} else if err != nil {
return
}
@@ -140,14 +141,23 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
// Return the IDs.
var eventID string
+ var token types.TopologyToken
+ var tokens []types.TopologyToken
for rows.Next() {
- if err = rows.Scan(&eventID); err != nil {
+ if err = rows.Scan(&eventID, &token.Depth, &token.PDUPosition); err != nil {
return
}
eventIDs = append(eventIDs, eventID)
+ tokens = append(tokens, token)
}
- return eventIDs, rows.Err()
+ // The values are already ordered by SQL, so we can use them as is.
+ if len(tokens) > 0 {
+ start = tokens[0]
+ end = tokens[len(tokens)-1]
+ }
+
+ return eventIDs, start, end, rows.Err()
}
// SelectPositionInTopology returns the position of a given event in the
diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go
index 8e79b71d..cd17fdc6 100644
--- a/syncapi/storage/shared/storage_sync.go
+++ b/syncapi/storage/shared/storage_sync.go
@@ -237,7 +237,7 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange(
roomID string,
filter *synctypes.RoomEventFilter,
backwardOrdering bool,
-) (events []types.StreamEvent, err error) {
+) (events []types.StreamEvent, start, end types.TopologyToken, err error) {
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
if backwardOrdering {
// Backward ordering means the 'from' token has a higher depth than the 'to' token
@@ -255,7 +255,7 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange(
// Select the event IDs from the defined range.
var eIDs []string
- eIDs, err = d.Topology.SelectEventIDsInRange(
+ eIDs, start, end, err = d.Topology.SelectEventIDsInRange(
ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
)
if err != nil {
@@ -264,6 +264,10 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange(
// Retrieve the events' contents using their IDs.
events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true)
+ if err != nil {
+ return
+ }
+
return
}
diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go
index 68b75f5b..614e1df9 100644
--- a/syncapi/storage/sqlite3/output_room_events_topology_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go
@@ -44,14 +44,14 @@ const insertEventInTopologySQL = "" +
" ON CONFLICT DO NOTHING"
const selectEventIDsInRangeASCSQL = "" +
- "SELECT event_id FROM syncapi_output_room_events_topology" +
+ "SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND (" +
"(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position >= $5)" +
") ORDER BY topological_position ASC, stream_position ASC LIMIT $6"
const selectEventIDsInRangeDESCSQL = "" +
- "SELECT event_id FROM syncapi_output_room_events_topology" +
+ "SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 AND (" +
"(topological_position > $2 AND topological_position < $3) OR" +
"(topological_position = $4 AND stream_position <= $5)" +
@@ -111,11 +111,15 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
return types.StreamPosition(event.Depth()), err
}
+// SelectEventIDsInRange selects the IDs of events which positions are within a
+// given range in a given room's topological order. Returns the start/end topological tokens for
+// the returned eventIDs.
+// Returns an empty slice if no events match the given range.
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
ctx context.Context, txn *sql.Tx, roomID string,
minDepth, maxDepth, maxStreamPos types.StreamPosition,
limit int, chronologicalOrder bool,
-) (eventIDs []string, err error) {
+) (eventIDs []string, start, end types.TopologyToken, err error) {
// Decide on the selection's order according to whether chronological order
// is requested or not.
var stmt *sql.Stmt
@@ -129,18 +133,27 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit)
if err == sql.ErrNoRows {
// If no event matched the request, return an empty slice.
- return []string{}, nil
+ return []string{}, start, end, nil
} else if err != nil {
return
}
// Return the IDs.
var eventID string
+ var token types.TopologyToken
+ var tokens []types.TopologyToken
for rows.Next() {
- if err = rows.Scan(&eventID); err != nil {
+ if err = rows.Scan(&eventID, &token.Depth, &token.PDUPosition); err != nil {
return
}
eventIDs = append(eventIDs, eventID)
+ tokens = append(tokens, token)
+ }
+
+ // The values are already ordered by SQL, so we can use them as is.
+ if len(tokens) > 0 {
+ start = tokens[0]
+ end = tokens[len(tokens)-1]
}
return
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index f57b0d61..ce7ca3fc 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -213,12 +213,48 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
// backpaginate 5 messages starting at the latest position.
filter := &synctypes.RoomEventFilter{Limit: 5}
- paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
+ paginatedEvents, start, end, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
if err != nil {
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
}
gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil)
test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
+ assert.Equal(t, types.TopologyToken{Depth: 15, PDUPosition: 15}, start)
+ assert.Equal(t, types.TopologyToken{Depth: 11, PDUPosition: 11}, end)
+ })
+ })
+}
+
+// The purpose of this test is to ensure that backfilling returns no start/end if a given filter removes
+// all events.
+func TestGetEventsInRangeWithTopologyTokenNoEventsForFilter(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := MustCreateDatabase(t, dbType)
+ defer close()
+ alice := test.NewUser(t)
+ r := test.NewRoom(t, alice)
+ for i := 0; i < 10; i++ {
+ r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
+ }
+ events := r.Events()
+ _ = MustWriteEvents(t, db, events)
+
+ WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
+ from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}
+ t.Logf("max topo pos = %+v", from)
+ // head towards the beginning of time
+ to := types.TopologyToken{}
+
+ // backpaginate 20 messages starting at the latest position.
+ notTypes := []string{spec.MRoomRedaction}
+ senders := []string{alice.ID}
+ filter := &synctypes.RoomEventFilter{Limit: 20, NotTypes: &notTypes, Senders: &senders}
+ paginatedEvents, start, end, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
+ assert.NoError(t, err)
+ assert.Equal(t, 0, len(paginatedEvents))
+ // Even if we didn't get anything back due to the filter, we should still have start/end
+ assert.Equal(t, types.TopologyToken{Depth: 15, PDUPosition: 15}, start)
+ assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 1}, end)
})
})
}
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index 854292bd..f5c66c42 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -89,11 +89,11 @@ type Topology interface {
// InsertEventInTopology inserts the given event in the room's topology, based on the event's depth.
// `pos` is the stream position of this event in the events table, and is used to order events which have the same depth.
InsertEventInTopology(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, pos types.StreamPosition) (topoPos types.StreamPosition, err error)
- // SelectEventIDsInRange selects the IDs of events whose depths are within a given range in a given room's topological order.
- // Events with `minDepth` are *exclusive*, as is the event which has exactly `minDepth`,`maxStreamPos`.
+ // SelectEventIDsInRange selects the IDs and the topological position of events whose depths are within a given range in a given room's topological order.
+ // Events with `minDepth` are *exclusive*, as is the event which has exactly `minDepth`,`maxStreamPos`. Returns the eventIDs and start/end topological tokens.
// `maxStreamPos` is only used when events have the same depth as `maxDepth`, which results in events less than `maxStreamPos` being returned.
// Returns an empty slice if no events match the given range.
- SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error)
+ SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, start, end types.TopologyToken, err error)
// SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to.
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
diff --git a/syncapi/storage/tables/topology_test.go b/syncapi/storage/tables/topology_test.go
index f4f75bdf..7691cc5f 100644
--- a/syncapi/storage/tables/topology_test.go
+++ b/syncapi/storage/tables/topology_test.go
@@ -13,6 +13,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
+ "github.com/stretchr/testify/assert"
)
func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
@@ -60,28 +61,37 @@ func TestTopologyTable(t *testing.T) {
highestPos = topoPos + 1
}
// check ordering works without limit
- eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
- if err != nil {
- return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
- }
+ eventIDs, start, end, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
+ assert.NoError(t, err, "failed to SelectEventIDsInRange")
test.AssertEventIDsEqual(t, eventIDs, events[:])
- eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
- if err != nil {
- return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
- }
+ assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 0}, start)
+ assert.Equal(t, types.TopologyToken{Depth: 5, PDUPosition: 4}, end)
+
+ eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
+ assert.NoError(t, err, "failed to SelectEventIDsInRange")
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
+ assert.Equal(t, types.TopologyToken{Depth: 5, PDUPosition: 4}, start)
+ assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 0}, end)
+
// check ordering works with limit
- eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
- if err != nil {
- return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
- }
+ eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
+ assert.NoError(t, err, "failed to SelectEventIDsInRange")
test.AssertEventIDsEqual(t, eventIDs, events[:3])
- eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
- if err != nil {
- return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
- }
+ assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 0}, start)
+ assert.Equal(t, types.TopologyToken{Depth: 3, PDUPosition: 2}, end)
+
+ eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
+ assert.NoError(t, err, "failed to SelectEventIDsInRange")
test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
+ assert.Equal(t, types.TopologyToken{Depth: 5, PDUPosition: 4}, start)
+ assert.Equal(t, types.TopologyToken{Depth: 3, PDUPosition: 2}, end)
+ // Check that we return no values for invalid rooms
+ eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, "!doesnotexist:localhost", 0, highestPos, highestPos, 10, false)
+ assert.NoError(t, err, "failed to SelectEventIDsInRange")
+ assert.Equal(t, 0, len(eventIDs))
+ assert.Equal(t, types.TopologyToken{}, start)
+ assert.Equal(t, types.TopologyToken{}, end)
return nil
})
if err != nil {
diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go
index 64a4af75..af6bddc7 100644
--- a/syncapi/syncapi.go
+++ b/syncapi/syncapi.go
@@ -144,8 +144,11 @@ func AddPublicRoutes(
logrus.WithError(err).Panicf("failed to start receipts consumer")
}
+ rateLimits := httputil.NewRateLimits(&dendriteCfg.ClientAPI.RateLimiting)
+
routing.Setup(
routers.Client, requestPool, syncDB, userAPI,
rsAPI, &dendriteCfg.SyncAPI, caches, fts,
+ rateLimits,
)
}
diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go
index 19815b79..996b21e9 100644
--- a/syncapi/syncapi_test.go
+++ b/syncapi/syncapi_test.go
@@ -433,6 +433,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
}
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ cfg.ClientAPI.RateLimiting = config.RateLimiting{Enabled: false}
routers := httputil.NewRouters()
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)