diff options
author | Kegsay <kegan@matrix.org> | 2020-05-01 11:01:34 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-01 11:01:34 +0100 |
commit | b28674435e3024bf0e4723a9cf53180607f2045e (patch) | |
tree | f1e7d1f717624bb8419250626109ac16c52af103 /syncapi/storage/storage_test.go | |
parent | e15f6676ac3f76ec2ef679c2df300d6a8e7e668f (diff) |
Correctly generate backpagination tokens for events which have the same depth (#996)
* Correctly generate backpagination tokens for events which have the same depth
With tests. Unfortunately the code around here is hard to understand.
There will be a subsequent PR which fixes this up now that we have a test
harness in place.
* Add postgres impl
* More linting
* Fix psql statement so it actually works
Diffstat (limited to 'syncapi/storage/storage_test.go')
-rw-r--r-- | syncapi/storage/storage_test.go | 108 |
1 files changed, 104 insertions, 4 deletions
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index e591e7ed..a57d5917 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -182,7 +182,7 @@ func TestSyncResponse(t *testing.T) { // limit set to 5 return db.CompleteSync(ctx, testUserIDA, 5) }, - // want the last 5 events, NOT the last 10. + // want the last 5 events WantTimeline: events[len(events)-5:], // want all state for the room WantState: state, @@ -248,11 +248,11 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { db := MustCreateDatabase(t) events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB) MustWriteEvents(t, db, events) - latest, err := db.MaxTopologicalPosition(ctx, testRoomID) + latest, latestStream, err := db.MaxTopologicalPosition(ctx, testRoomID) if err != nil { t.Fatalf("failed to get MaxTopologicalPosition: %s", err) } - from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, 0) + from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, latestStream) // head towards the beginning of time to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) @@ -265,6 +265,105 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:])) } +// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. +// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent +// will appear FIRST when going backwards. This test creates a DAG like: +// .-----> Message ---. +// Create -> Membership --------> Message -------> Message +// `-----> Message ---` +// depth 1 2 3 4 +// +// With a total depth of 4. It tests that: +// - Backpagination over the whole fork should include all messages and not leave any out. +// - Backpagination from the middle of the fork should not return duplicates (things later than the token). +func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) { + t.Parallel() + db := MustCreateDatabase(t) + + var events []gomatrixserverlib.HeaderedEvent + events = append(events, MustCreateEvent(t, testRoomID, nil, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, testUserIDA)), + Type: "m.room.create", + StateKey: &emptyStateKey, + Sender: testUserIDA, + Depth: int64(len(events) + 1), + })) + events = append(events, MustCreateEvent(t, testRoomID, []gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"membership":"join"}`)), + Type: "m.room.member", + StateKey: &testUserIDA, + Sender: testUserIDA, + Depth: int64(len(events) + 1), + })) + // fork the dag into three, same prev_events and depth + parent := []gomatrixserverlib.HeaderedEvent{events[len(events)-1]} + depth := int64(len(events) + 1) + for i := 0; i < 3; i++ { + events = append(events, MustCreateEvent(t, testRoomID, parent, &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)), + Type: "m.room.message", + Sender: testUserIDA, + Depth: depth, + })) + } + // merge the fork, prev_events are all 3 messages, depth is increased by 1. + events = append(events, MustCreateEvent(t, testRoomID, events[len(events)-3:], &gomatrixserverlib.EventBuilder{ + Content: []byte(fmt.Sprintf(`{"body":"Message merge"}`)), + Type: "m.room.message", + Sender: testUserIDA, + Depth: depth + 1, + })) + MustWriteEvents(t, db, events) + latestPos, latestStreamPos, err := db.EventPositionInTopology(ctx, events[len(events)-1].EventID()) + if err != nil { + t.Fatalf("failed to get EventPositionInTopology: %s", err) + } + topoPos, streamPos, err := db.EventPositionInTopology(ctx, events[len(events)-3].EventID()) // Message 2 + if err != nil { + t.Fatalf("failed to get EventPositionInTopology for event: %s", err) + } + fromLatest := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latestPos, latestStreamPos) + fromFork := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, topoPos, streamPos) + // head towards the beginning of time + to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) + + testCases := []struct { + Name string + From *types.PaginationToken + Limit int + Wants []gomatrixserverlib.HeaderedEvent + }{ + { + Name: "Pagination over the whole fork", + From: fromLatest, + Limit: 5, + Wants: reversed(events[len(events)-5:]), + }, + { + Name: "Paginating to the middle of the fork", + From: fromLatest, + Limit: 2, + Wants: reversed(events[len(events)-2:]), + }, + { + Name: "Pagination FROM the middle of the fork", + From: fromFork, + Limit: 3, + Wants: reversed(events[len(events)-5 : len(events)-2]), + }, + } + + for _, tc := range testCases { + // backpaginate messages starting at the latest position. + paginatedEvents, err := db.GetEventsInRange(ctx, tc.From, to, testRoomID, tc.Limit, true) + if err != nil { + t.Fatalf("%s GetEventsInRange returned an error: %s", tc.Name, err) + } + gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll) + assertEventsEqual(t, tc.Name, true, gots, tc.Wants) + } +} + func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) { if len(gots) != len(wants) { t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants)) @@ -294,7 +393,8 @@ func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatr t.Errorf("%s event[%d] unsigned mismatch: got %s want %s", msg, i, string(g.Unsigned), string(w.Unsigned())) } if (g.StateKey == nil && w.StateKey() != nil) || (g.StateKey != nil && w.StateKey() == nil) { - t.Fatalf("%s event[%d] state_key [not] missing: got %v want %v", msg, i, g.StateKey, w.StateKey()) + t.Errorf("%s event[%d] state_key [not] missing: got %v want %v", msg, i, g.StateKey, w.StateKey()) + continue } if g.StateKey != nil { if !w.StateKeyEquals(*g.StateKey) { |