aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage/storage_test.go
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-05-01 11:01:34 +0100
committerGitHub <noreply@github.com>2020-05-01 11:01:34 +0100
commitb28674435e3024bf0e4723a9cf53180607f2045e (patch)
treef1e7d1f717624bb8419250626109ac16c52af103 /syncapi/storage/storage_test.go
parente15f6676ac3f76ec2ef679c2df300d6a8e7e668f (diff)
Correctly generate backpagination tokens for events which have the same depth (#996)
* Correctly generate backpagination tokens for events which have the same depth With tests. Unfortunately the code around here is hard to understand. There will be a subsequent PR which fixes this up now that we have a test harness in place. * Add postgres impl * More linting * Fix psql statement so it actually works
Diffstat (limited to 'syncapi/storage/storage_test.go')
-rw-r--r--syncapi/storage/storage_test.go108
1 files changed, 104 insertions, 4 deletions
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index e591e7ed..a57d5917 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -182,7 +182,7 @@ func TestSyncResponse(t *testing.T) {
// limit set to 5
return db.CompleteSync(ctx, testUserIDA, 5)
},
- // want the last 5 events, NOT the last 10.
+ // want the last 5 events
WantTimeline: events[len(events)-5:],
// want all state for the room
WantState: state,
@@ -248,11 +248,11 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
db := MustCreateDatabase(t)
events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
MustWriteEvents(t, db, events)
- latest, err := db.MaxTopologicalPosition(ctx, testRoomID)
+ latest, latestStream, err := db.MaxTopologicalPosition(ctx, testRoomID)
if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
}
- from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, 0)
+ from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, latestStream)
// head towards the beginning of time
to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0)
@@ -265,6 +265,105 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
}
+// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
+// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
+// will appear FIRST when going backwards. This test creates a DAG like:
+// .-----> Message ---.
+// Create -> Membership --------> Message -------> Message
+// `-----> Message ---`
+// depth 1 2 3 4
+//
+// With a total depth of 4. It tests that:
+// - Backpagination over the whole fork should include all messages and not leave any out.
+// - Backpagination from the middle of the fork should not return duplicates (things later than the token).
+func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) {
+ t.Parallel()
+ db := MustCreateDatabase(t)
+
+ var events []gomatrixserverlib.HeaderedEvent
+ events = append(events, MustCreateEvent(t, testRoomID, nil, &gomatrixserverlib.EventBuilder{
+ Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, testUserIDA)),
+ Type: "m.room.create",
+ StateKey: &emptyStateKey,
+ Sender: testUserIDA,
+ Depth: int64(len(events) + 1),
+ }))
+ events = append(events, MustCreateEvent(t, testRoomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
+ Content: []byte(fmt.Sprintf(`{"membership":"join"}`)),
+ Type: "m.room.member",
+ StateKey: &testUserIDA,
+ Sender: testUserIDA,
+ Depth: int64(len(events) + 1),
+ }))
+ // fork the dag into three, same prev_events and depth
+ parent := []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}
+ depth := int64(len(events) + 1)
+ for i := 0; i < 3; i++ {
+ events = append(events, MustCreateEvent(t, testRoomID, parent, &gomatrixserverlib.EventBuilder{
+ Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)),
+ Type: "m.room.message",
+ Sender: testUserIDA,
+ Depth: depth,
+ }))
+ }
+ // merge the fork, prev_events are all 3 messages, depth is increased by 1.
+ events = append(events, MustCreateEvent(t, testRoomID, events[len(events)-3:], &gomatrixserverlib.EventBuilder{
+ Content: []byte(fmt.Sprintf(`{"body":"Message merge"}`)),
+ Type: "m.room.message",
+ Sender: testUserIDA,
+ Depth: depth + 1,
+ }))
+ MustWriteEvents(t, db, events)
+ latestPos, latestStreamPos, err := db.EventPositionInTopology(ctx, events[len(events)-1].EventID())
+ if err != nil {
+ t.Fatalf("failed to get EventPositionInTopology: %s", err)
+ }
+ topoPos, streamPos, err := db.EventPositionInTopology(ctx, events[len(events)-3].EventID()) // Message 2
+ if err != nil {
+ t.Fatalf("failed to get EventPositionInTopology for event: %s", err)
+ }
+ fromLatest := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latestPos, latestStreamPos)
+ fromFork := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, topoPos, streamPos)
+ // head towards the beginning of time
+ to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0)
+
+ testCases := []struct {
+ Name string
+ From *types.PaginationToken
+ Limit int
+ Wants []gomatrixserverlib.HeaderedEvent
+ }{
+ {
+ Name: "Pagination over the whole fork",
+ From: fromLatest,
+ Limit: 5,
+ Wants: reversed(events[len(events)-5:]),
+ },
+ {
+ Name: "Paginating to the middle of the fork",
+ From: fromLatest,
+ Limit: 2,
+ Wants: reversed(events[len(events)-2:]),
+ },
+ {
+ Name: "Pagination FROM the middle of the fork",
+ From: fromFork,
+ Limit: 3,
+ Wants: reversed(events[len(events)-5 : len(events)-2]),
+ },
+ }
+
+ for _, tc := range testCases {
+ // backpaginate messages starting at the latest position.
+ paginatedEvents, err := db.GetEventsInRange(ctx, tc.From, to, testRoomID, tc.Limit, true)
+ if err != nil {
+ t.Fatalf("%s GetEventsInRange returned an error: %s", tc.Name, err)
+ }
+ gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
+ assertEventsEqual(t, tc.Name, true, gots, tc.Wants)
+ }
+}
+
func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) {
if len(gots) != len(wants) {
t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants))
@@ -294,7 +393,8 @@ func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatr
t.Errorf("%s event[%d] unsigned mismatch: got %s want %s", msg, i, string(g.Unsigned), string(w.Unsigned()))
}
if (g.StateKey == nil && w.StateKey() != nil) || (g.StateKey != nil && w.StateKey() == nil) {
- t.Fatalf("%s event[%d] state_key [not] missing: got %v want %v", msg, i, g.StateKey, w.StateKey())
+ t.Errorf("%s event[%d] state_key [not] missing: got %v want %v", msg, i, g.StateKey, w.StateKey())
+ continue
}
if g.StateKey != nil {
if !w.StateKeyEquals(*g.StateKey) {