aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkegsay <kegan@matrix.org>2022-04-08 17:53:24 +0100
committerGitHub <noreply@github.com>2022-04-08 17:53:24 +0100
commit6d25bd6ca57f518404000c47d69bcbfadb4fd2ef (patch)
tree14246daf0644ccbafd255b930887695382bbe8f8
parent986d27a1287e9d86fe16a6152f2457657513a6dd (diff)
syncapi: add more tests; fix more bugs (#2338)
* syncapi: add more tests; fix more bugs bugfixes: - The postgres impl of TopologyTable.SelectEventIDsInRange did not use the provided txn - The postgres impl of EventsTable.SelectEvents did not preserve the ordering of the input event IDs in the output events slice - The sqlite impl of EventsTable.SelectEvents did not use a bulk `IN ($1)` query. Added tests: - `TestGetEventsInRangeWithTopologyToken` - `TestOutputRoomEventsTable` - `TestTopologyTable` * -p 1 for now
-rw-r--r--.github/workflows/dendrite.yml2
-rw-r--r--syncapi/storage/interface.go2
-rw-r--r--syncapi/storage/postgres/output_room_events_table.go22
-rw-r--r--syncapi/storage/postgres/output_room_events_topology_table.go4
-rw-r--r--syncapi/storage/shared/syncserver.go8
-rw-r--r--syncapi/storage/sqlite3/account_data_table.go4
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go4
-rw-r--r--syncapi/storage/sqlite3/invites_table.go4
-rw-r--r--syncapi/storage/sqlite3/output_room_events_table.go52
-rw-r--r--syncapi/storage/sqlite3/peeks_table.go4
-rw-r--r--syncapi/storage/sqlite3/presence_table.go4
-rw-r--r--syncapi/storage/sqlite3/receipt_table.go4
-rw-r--r--syncapi/storage/sqlite3/stream_id_table.go14
-rw-r--r--syncapi/storage/sqlite3/syncserver.go4
-rw-r--r--syncapi/storage/storage_test.go234
-rw-r--r--syncapi/storage/tables/interface.go2
-rw-r--r--syncapi/storage/tables/output_room_events_test.go82
-rw-r--r--syncapi/storage/tables/topology_test.go91
-rw-r--r--test/db.go1
-rw-r--r--test/event.go39
20 files changed, 386 insertions, 195 deletions
diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml
index 4f337a86..8221bff9 100644
--- a/.github/workflows/dendrite.yml
+++ b/.github/workflows/dendrite.yml
@@ -111,7 +111,7 @@ jobs:
key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-
- - run: go test ./...
+ - run: go test -p 1 ./...
env:
POSTGRES_HOST: localhost
POSTGRES_USER: postgres
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index 841f6726..cf3fd553 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -104,7 +104,7 @@ type Database interface {
// DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
- // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
+ // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardOrdering is true, the most recent event must be first, else last.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go
index 14af6a94..a30e220b 100644
--- a/syncapi/storage/postgres/output_room_events_table.go
+++ b/syncapi/storage/postgres/output_room_events_table.go
@@ -427,7 +427,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents(
- ctx context.Context, txn *sql.Tx, eventIDs []string,
+ ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
@@ -435,7 +435,25 @@ func (s *outputRoomEventsStatements) SelectEvents(
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
- return rowsToStreamEvents(rows)
+ streamEvents, err := rowsToStreamEvents(rows)
+ if err != nil {
+ return nil, err
+ }
+ if preserveOrder {
+ eventMap := make(map[string]types.StreamEvent)
+ for _, ev := range streamEvents {
+ eventMap[ev.EventID()] = ev
+ }
+ var returnEvents []types.StreamEvent
+ for _, eventID := range eventIDs {
+ ev, ok := eventMap[eventID]
+ if ok {
+ returnEvents = append(returnEvents, ev)
+ }
+ }
+ return returnEvents, nil
+ }
+ return streamEvents, nil
}
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go
index 626386ba..90b3b008 100644
--- a/syncapi/storage/postgres/output_room_events_topology_table.go
+++ b/syncapi/storage/postgres/output_room_events_topology_table.go
@@ -148,9 +148,9 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
// is requested or not.
var stmt *sql.Stmt
if chronologicalOrder {
- stmt = s.selectEventIDsInRangeASCStmt
+ stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeASCStmt)
} else {
- stmt = s.selectEventIDsInRangeDESCStmt
+ stmt = sqlutil.TxStmt(txn, s.selectEventIDsInRangeDESCStmt)
}
// Query the event IDs.
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index 1c45d5d9..14db5795 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
// Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
- streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs)
+ streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false)
if err != nil {
return nil, err
}
@@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
// Check if we have all of the event's previous events. If an event is
// missing, add it to the room's backward extremities.
- prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs())
+ prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false)
if err != nil {
return err
}
@@ -457,7 +457,7 @@ func (d *Database) GetEventsInTopologicalRange(
}
// Retrieve the events' contents using their IDs.
- events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs)
+ events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true)
return
}
@@ -619,7 +619,7 @@ func (d *Database) fetchMissingStateEvents(
) ([]types.StreamEvent, error) {
// Fetch from the events table first so we pick up the stream ID for the
// event.
- events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs)
+ events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false)
if err != nil {
return nil, err
}
diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go
index 24c44224..5b2287e6 100644
--- a/syncapi/storage/sqlite3/account_data_table.go
+++ b/syncapi/storage/sqlite3/account_data_table.go
@@ -51,13 +51,13 @@ const selectMaxAccountDataIDSQL = "" +
type accountDataStatements struct {
db *sql.DB
- streamIDStatements *streamIDStatements
+ streamIDStatements *StreamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
selectAccountDataInRangeStmt *sql.Stmt
}
-func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
+func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
db: db,
streamIDStatements: streamID,
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index 473aa49b..464f32e0 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct {
db *sql.DB
- streamIDStatements *streamIDStatements
+ streamIDStatements *StreamIDStatements
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
deleteRoomStateForRoomStmt *sql.Stmt
@@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
selectStateEventStmt *sql.Stmt
}
-func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
+func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
db: db,
streamIDStatements: streamID,
diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go
index 0a6823cc..58ab8461 100644
--- a/syncapi/storage/sqlite3/invites_table.go
+++ b/syncapi/storage/sqlite3/invites_table.go
@@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct {
db *sql.DB
- streamIDStatements *streamIDStatements
+ streamIDStatements *StreamIDStatements
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
deleteInviteEventStmt *sql.Stmt
selectMaxInviteIDStmt *sql.Stmt
}
-func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
+func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{
db: db,
streamIDStatements: streamID,
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
index acd95969..9da9d776 100644
--- a/syncapi/storage/sqlite3/output_room_events_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -58,7 +58,7 @@ const insertEventSQL = "" +
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
const selectEventsSQL = "" +
- "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
+ "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
@@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
type outputRoomEventsStatements struct {
db *sql.DB
- streamIDStatements *streamIDStatements
+ streamIDStatements *StreamIDStatements
insertEventStmt *sql.Stmt
- selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
@@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
selectContextAfterEventStmt *sql.Stmt
}
-func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
+func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{
db: db,
streamIDStatements: streamID,
@@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
}
return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL},
- {&s.selectEventsStmt, selectEventsSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
{&s.updateEventJSONStmt, updateEventJSONSQL},
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
@@ -421,21 +419,43 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
// selectEvents returns the events for the given event IDs. If an event is
// missing from the database, it will be omitted.
func (s *outputRoomEventsStatements) SelectEvents(
- ctx context.Context, txn *sql.Tx, eventIDs []string,
+ ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
) ([]types.StreamEvent, error) {
- var returnEvents []types.StreamEvent
- stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
- for _, eventID := range eventIDs {
- rows, err := stmt.QueryContext(ctx, eventID)
- if err != nil {
- return nil, err
+ iEventIDs := make([]interface{}, len(eventIDs))
+ for i := range eventIDs {
+ iEventIDs[i] = eventIDs[i]
+ }
+ selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
+ var rows *sql.Rows
+ var err error
+ if txn != nil {
+ rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...)
+ } else {
+ rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...)
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
+ streamEvents, err := rowsToStreamEvents(rows)
+ if err != nil {
+ return nil, err
+ }
+ if preserveOrder {
+ var returnEvents []types.StreamEvent
+ eventMap := make(map[string]types.StreamEvent)
+ for _, ev := range streamEvents {
+ eventMap[ev.EventID()] = ev
}
- if streamEvents, err := rowsToStreamEvents(rows); err == nil {
- returnEvents = append(returnEvents, streamEvents...)
+ for _, eventID := range eventIDs {
+ ev, ok := eventMap[eventID]
+ if ok {
+ returnEvents = append(returnEvents, ev)
+ }
}
- internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
+ return returnEvents, nil
}
- return returnEvents, nil
+ return streamEvents, nil
}
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go
index c93c8205..5ee86448 100644
--- a/syncapi/storage/sqlite3/peeks_table.go
+++ b/syncapi/storage/sqlite3/peeks_table.go
@@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
type peekStatements struct {
db *sql.DB
- streamIDStatements *streamIDStatements
+ streamIDStatements *StreamIDStatements
insertPeekStmt *sql.Stmt
deletePeekStmt *sql.Stmt
deletePeeksStmt *sql.Stmt
@@ -75,7 +75,7 @@ type peekStatements struct {
selectMaxPeekIDStmt *sql.Stmt
}
-func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
+func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
_, err := db.Exec(peeksSchema)
if err != nil {
return nil, err
diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go
index e7b78a70..00b16458 100644
--- a/syncapi/storage/sqlite3/presence_table.go
+++ b/syncapi/storage/sqlite3/presence_table.go
@@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
type presenceStatements struct {
db *sql.DB
- streamIDStatements *streamIDStatements
+ streamIDStatements *StreamIDStatements
upsertPresenceStmt *sql.Stmt
upsertPresenceFromSyncStmt *sql.Stmt
selectPresenceForUsersStmt *sql.Stmt
@@ -83,7 +83,7 @@ type presenceStatements struct {
selectPresenceAfterStmt *sql.Stmt
}
-func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) {
+func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
_, err := db.Exec(presenceSchema)
if err != nil {
return nil, err
diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go
index dea05771..bd778bf3 100644
--- a/syncapi/storage/sqlite3/receipt_table.go
+++ b/syncapi/storage/sqlite3/receipt_table.go
@@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
type receiptStatements struct {
db *sql.DB
- streamIDStatements *streamIDStatements
+ streamIDStatements *StreamIDStatements
upsertReceipt *sql.Stmt
selectRoomReceipts *sql.Stmt
selectMaxReceiptID *sql.Stmt
}
-func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
+func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
_, err := db.Exec(receiptsSchema)
if err != nil {
return nil, err
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
index faa2c41f..71980b80 100644
--- a/syncapi/storage/sqlite3/stream_id_table.go
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
" RETURNING stream_id"
-type streamIDStatements struct {
+type StreamIDStatements struct {
db *sql.DB
increaseStreamIDStmt *sql.Stmt
}
-func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
+func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
s.db = db
_, err = db.Exec(streamIDTableSchema)
if err != nil {
@@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
return
}
-func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
return
}
-func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}
-func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return
}
-func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return
}
-func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
+func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
return
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
index 9d9d3598..dfc28948 100644
--- a/syncapi/storage/sqlite3/syncserver.go
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -30,7 +30,7 @@ type SyncServerDatasource struct {
shared.Database
db *sql.DB
writer sqlutil.Writer
- streamID streamIDStatements
+ streamID StreamIDStatements
}
// NewDatabase creates a new sync server database
@@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
}
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
- if err = d.streamID.prepare(d.db); err != nil {
+ if err = d.streamID.Prepare(d.db); err != nil {
return err
}
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index 403b50ea..4e1634ec 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -3,6 +3,7 @@ package storage_test
import (
"context"
"fmt"
+ "reflect"
"testing"
"github.com/matrix-org/dendrite/setup/config"
@@ -38,7 +39,7 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
if err != nil {
t.Fatalf("WriteEvent failed: %s", err)
}
- fmt.Println("Event ID", ev.EventID(), " spos=", pos, "depth=", ev.Depth())
+ t.Logf("Event ID %s spos=%v depth=%v", ev.EventID(), pos, ev.Depth())
positions = append(positions, pos)
}
return
@@ -46,7 +47,6 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
func TestWriteEvents(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- t.Parallel()
alice := test.NewUser()
r := test.NewRoom(t, alice)
db, close := MustCreateDatabase(t, dbType)
@@ -61,84 +61,84 @@ func TestRecentEventsPDU(t *testing.T) {
db, close := MustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
- var filter gomatrixserverlib.RoomEventFilter
- filter.Limit = 100
+ // dummy room to make sure SQL queries are filtering on room ID
+ MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
+
+ // actual test room
r := test.NewRoom(t, alice)
r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
events := r.Events()
positions := MustWriteEvents(t, db, events)
+
+ // dummy room to make sure SQL queries are filtering on room ID
+ MustWriteEvents(t, db, test.NewRoom(t, alice).Events())
+
latest, err := db.MaxStreamPositionForPDUs(ctx)
if err != nil {
t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
}
testCases := []struct {
- Name string
- From types.StreamPosition
- To types.StreamPosition
- WantEvents []*gomatrixserverlib.HeaderedEvent
- WantLimited bool
+ Name string
+ From types.StreamPosition
+ To types.StreamPosition
+ Limit int
+ ReverseOrder bool
+ WantEvents []*gomatrixserverlib.HeaderedEvent
+ WantLimited bool
}{
// The purpose of this test is to make sure that incremental syncs are including up to the latest events.
- // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
+ // It's a basic sanity test that sync works. It creates a streaming position that is on the penultimate event.
// It makes sure the response includes the final event.
{
- Name: "IncrementalSync penultimate",
+ Name: "penultimate",
From: positions[len(positions)-2], // pretend we are at the penultimate event
To: latest,
+ Limit: 100,
WantEvents: events[len(events)-1:],
WantLimited: false,
},
- /*
- // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
- // number of returned events. This is critical for big rooms hence the test here.
- {
- Name: "IncrementalSync limited",
- DoSync: func() (*types.Response, error) {
- from := types.StreamingToken{ // pretend we are 10 events behind
- PDUPosition: positions[len(positions)-11],
- }
- res := types.NewResponse()
- // limit is set to 5
- return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
- },
- // want the last 5 events, NOT the last 10.
- WantTimeline: events[len(events)-5:],
- },
- // The purpose of this test is to check that CompleteSync returns all the current state as well as
- // honouring the `numRecentEventsPerRoom` value
- {
- Name: "CompleteSync limited",
- DoSync: func() (*types.Response, error) {
- res := types.NewResponse()
- // limit set to 5
- return db.CompleteSync(ctx, res, testUserDeviceA, 5)
- },
- // want the last 5 events
- WantTimeline: events[len(events)-5:],
- // want all state for the room
- WantState: state,
- },
- // The purpose of this test is to check that CompleteSync can return everything with a high enough
- // `numRecentEventsPerRoom`.
- {
- Name: "CompleteSync",
- DoSync: func() (*types.Response, error) {
- res := types.NewResponse()
- return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
- },
- WantTimeline: events,
- // We want no state at all as that field in /sync is the delta between the token (beginning of time)
- // and the START of the timeline.
- }, */
+ // The purpose of this test is to check that limits can be applied and work.
+ // This is critical for big rooms hence the test here.
+ {
+ Name: "limited",
+ From: 0,
+ To: latest,
+ Limit: 1,
+ WantEvents: events[len(events)-1:],
+ WantLimited: true,
+ },
+ // The purpose of this test is to check that we can return every event with a high
+ // enough limit
+ {
+ Name: "large limited",
+ From: 0,
+ To: latest,
+ Limit: 100,
+ WantEvents: events,
+ WantLimited: false,
+ },
+ // The purpose of this test is to check that we can return events in reverse order
+ {
+ Name: "reverse",
+ From: positions[len(positions)-3], // 2 events back
+ To: latest,
+ Limit: 100,
+ ReverseOrder: true,
+ WantEvents: test.Reversed(events[len(events)-2:]),
+ WantLimited: false,
+ },
}
- for _, tc := range testCases {
+ for i := range testCases {
+ tc := testCases[i]
t.Run(tc.Name, func(st *testing.T) {
+ var filter gomatrixserverlib.RoomEventFilter
+ filter.Limit = tc.Limit
gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
From: tc.From,
To: tc.To,
- }, &filter, true, true)
+ }, &filter, !tc.ReverseOrder, true)
if err != nil {
st.Fatalf("failed to do sync: %s", err)
}
@@ -148,100 +148,48 @@ func TestRecentEventsPDU(t *testing.T) {
if len(gotEvents) != len(tc.WantEvents) {
st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
}
+ for j := range gotEvents {
+ if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) {
+ st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON()))
+ }
+ }
})
}
})
}
-/*
-func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
- t.Parallel()
- db := MustCreateDatabase(t)
- events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
- positions := MustWriteEvents(t, db, events)
- latest, err := db.SyncPosition(ctx)
- if err != nil {
- t.Fatalf("failed to get SyncPosition: %s", err)
- }
- from := types.StreamingToken{
- PDUPosition: positions[len(positions)-2],
- }
-
- res := types.NewResponse()
- res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
- if err != nil {
- t.Fatalf("failed to IncrementalSync with latest token")
- }
- roomRes, ok := res.Rooms.Join[testRoomID]
- if !ok {
- t.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
- }
- // returns the last event "Message 10"
- assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
-
- prev := roomRes.Timeline.PrevBatch.String()
- if prev == "" {
- t.Fatalf("IncrementalSync expected prev_batch token")
- }
- prevBatchToken, err := types.NewTopologyTokenFromString(prev)
- if err != nil {
- t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
- }
- // backpaginate 5 messages starting at the latest position.
- // head towards the beginning of time
- to := types.TopologyToken{}
- paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
- if err != nil {
- t.Fatalf("GetEventsInRange returned an error: %s", err)
- }
- gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
- assertEventsEqual(t, "", true, gots, reversed(events[len(events)-6:len(events)-1]))
-}
-
-// The purpose of this test is to ensure that backfill does indeed go backwards, using a stream token.
-func TestGetEventsInRangeWithStreamToken(t *testing.T) {
- t.Parallel()
- db := MustCreateDatabase(t)
- events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
- MustWriteEvents(t, db, events)
- latest, err := db.SyncPosition(ctx)
- if err != nil {
- t.Fatalf("failed to get SyncPosition: %s", err)
- }
- // head towards the beginning of time
- to := types.StreamingToken{}
-
- // backpaginate 5 messages starting at the latest position.
- paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
- if err != nil {
- t.Fatalf("GetEventsInRange returned an error: %s", err)
- }
- gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
- assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
-}
-
// The purpose of this test is to ensure that backfill does indeed go backwards, using a topology token
func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
- t.Parallel()
- db := MustCreateDatabase(t)
- events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
- MustWriteEvents(t, db, events)
- from, err := db.MaxTopologicalPosition(ctx, testRoomID)
- if err != nil {
- t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
- }
- // head towards the beginning of time
- to := types.TopologyToken{}
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := MustCreateDatabase(t, dbType)
+ defer close()
+ alice := test.NewUser()
+ r := test.NewRoom(t, alice)
+ for i := 0; i < 10; i++ {
+ r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)})
+ }
+ events := r.Events()
+ _ = MustWriteEvents(t, db, events)
- // backpaginate 5 messages starting at the latest position.
- paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
- if err != nil {
- t.Fatalf("GetEventsInRange returned an error: %s", err)
- }
- gots := gomatrixserverlib.HeaderedToClientEvents(db.StreamEventsToEvents(&testUserDeviceA, paginatedEvents), gomatrixserverlib.FormatAll)
- assertEventsEqual(t, "", true, gots, reversed(events[len(events)-5:]))
+ from, err := db.MaxTopologicalPosition(ctx, r.ID)
+ if err != nil {
+ t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
+ }
+ t.Logf("max topo pos = %+v", from)
+ // head towards the beginning of time
+ to := types.TopologyToken{}
+
+ // backpaginate 5 messages starting at the latest position.
+ paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true)
+ if err != nil {
+ t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
+ }
+ gots := db.StreamEventsToEvents(nil, paginatedEvents)
+ test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:]))
+ })
}
+/*
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
// will appear FIRST when going backwards. This test creates a DAG like:
@@ -651,12 +599,4 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ
tok.Decrement()
return &tok
}
-
-func reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
- out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
- for i := 0; i < len(in); i++ {
- out[i] = in[len(in)-i-1]
- }
- return out
-}
*/
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index 8d368eec..3cbeb046 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -59,7 +59,7 @@ type Events interface {
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
- SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
+ SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
diff --git a/syncapi/storage/tables/output_room_events_test.go b/syncapi/storage/tables/output_room_events_test.go
new file mode 100644
index 00000000..7a81ffcd
--- /dev/null
+++ b/syncapi/storage/tables/output_room_events_test.go
@@ -0,0 +1,82 @@
+package tables_test
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/syncapi/storage/postgres"
+ "github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
+ "github.com/matrix-org/dendrite/syncapi/storage/tables"
+ "github.com/matrix-org/dendrite/test"
+)
+
+func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ })
+ if err != nil {
+ t.Fatalf("failed to open db: %s", err)
+ }
+
+ var tab tables.Events
+ switch dbType {
+ case test.DBTypePostgres:
+ tab, err = postgres.NewPostgresEventsTable(db)
+ case test.DBTypeSQLite:
+ var stream sqlite3.StreamIDStatements
+ if err = stream.Prepare(db); err != nil {
+ t.Fatalf("failed to prepare stream stmts: %s", err)
+ }
+ tab, err = sqlite3.NewSqliteEventsTable(db, &stream)
+ }
+ if err != nil {
+ t.Fatalf("failed to make new table: %s", err)
+ }
+ return tab, db, close
+}
+
+func TestOutputRoomEventsTable(t *testing.T) {
+ ctx := context.Background()
+ alice := test.NewUser()
+ room := test.NewRoom(t, alice)
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, db, close := newOutputRoomEventsTable(t, dbType)
+ defer close()
+ events := room.Events()
+ err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
+ for _, ev := range events {
+ _, err := tab.InsertEvent(ctx, txn, ev, nil, nil, nil, false)
+ if err != nil {
+ return fmt.Errorf("failed to InsertEvent: %s", err)
+ }
+ }
+ // order = 2,0,3,1
+ wantEventIDs := []string{
+ events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
+ }
+ gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true)
+ if err != nil {
+ return fmt.Errorf("failed to SelectEvents: %s", err)
+ }
+ gotEventIDs := make([]string, len(gotEvents))
+ for i := range gotEvents {
+ gotEventIDs[i] = gotEvents[i].EventID()
+ }
+ if !reflect.DeepEqual(gotEventIDs, wantEventIDs) {
+ return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
+ }
+
+ return nil
+ })
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+ })
+}
diff --git a/syncapi/storage/tables/topology_test.go b/syncapi/storage/tables/topology_test.go
new file mode 100644
index 00000000..b6ece0b0
--- /dev/null
+++ b/syncapi/storage/tables/topology_test.go
@@ -0,0 +1,91 @@
+package tables_test
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/syncapi/storage/postgres"
+ "github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
+ "github.com/matrix-org/dendrite/syncapi/storage/tables"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/dendrite/test"
+)
+
+func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ })
+ if err != nil {
+ t.Fatalf("failed to open db: %s", err)
+ }
+
+ var tab tables.Topology
+ switch dbType {
+ case test.DBTypePostgres:
+ tab, err = postgres.NewPostgresTopologyTable(db)
+ case test.DBTypeSQLite:
+ tab, err = sqlite3.NewSqliteTopologyTable(db)
+ }
+ if err != nil {
+ t.Fatalf("failed to make new table: %s", err)
+ }
+ return tab, db, close
+}
+
+func TestTopologyTable(t *testing.T) {
+ ctx := context.Background()
+ alice := test.NewUser()
+ room := test.NewRoom(t, alice)
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, db, close := newTopologyTable(t, dbType)
+ defer close()
+ events := room.Events()
+ err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
+ var highestPos types.StreamPosition
+ for i, ev := range events {
+ topoPos, err := tab.InsertEventInTopology(ctx, txn, ev, types.StreamPosition(i))
+ if err != nil {
+ return fmt.Errorf("failed to InsertEventInTopology: %s", err)
+ }
+ // topo pos = depth, depth starts at 1, hence 1+i
+ if topoPos != types.StreamPosition(1+i) {
+ return fmt.Errorf("got topo pos %d want %d", topoPos, 1+i)
+ }
+ highestPos = topoPos + 1
+ }
+ // check ordering works without limit
+ eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true)
+ if err != nil {
+ return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
+ }
+ test.AssertEventIDsEqual(t, eventIDs, events[:])
+ eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false)
+ if err != nil {
+ return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
+ }
+ test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:]))
+ // check ordering works with limit
+ eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true)
+ if err != nil {
+ return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
+ }
+ test.AssertEventIDsEqual(t, eventIDs, events[:3])
+ eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false)
+ if err != nil {
+ return fmt.Errorf("failed to SelectEventIDsInRange: %s", err)
+ }
+ test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:]))
+
+ return nil
+ })
+ if err != nil {
+ t.Fatalf("err: %s", err)
+ }
+ })
+}
diff --git a/test/db.go b/test/db.go
index 9deec0a8..674fdf5c 100644
--- a/test/db.go
+++ b/test/db.go
@@ -121,6 +121,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
for dbName, dbType := range dbs {
dbt := dbType
t.Run(dbName, func(tt *testing.T) {
+ tt.Parallel()
testFn(tt, dbt)
})
}
diff --git a/test/event.go b/test/event.go
index 487b0936..b2e2805b 100644
--- a/test/event.go
+++ b/test/event.go
@@ -15,7 +15,9 @@
package test
import (
+ "bytes"
"crypto/ed25519"
+ "testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
@@ -49,3 +51,40 @@ func WithUnsigned(unsigned interface{}) eventModifier {
e.unsigned = unsigned
}
}
+
+// Reversed returns a copy of the given list of events in reverse order.
+func Reversed(in []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
+ out := make([]*gomatrixserverlib.HeaderedEvent, len(in))
+ for i := 0; i < len(in); i++ {
+ out[i] = in[len(in)-i-1]
+ }
+ return out
+}
+
+func AssertEventIDsEqual(t *testing.T, gotEventIDs []string, wants []*gomatrixserverlib.HeaderedEvent) {
+ t.Helper()
+ if len(gotEventIDs) != len(wants) {
+ t.Fatalf("length mismatch: got %d events, want %d", len(gotEventIDs), len(wants))
+ }
+ for i := range wants {
+ w := wants[i].EventID()
+ g := gotEventIDs[i]
+ if w != g {
+ t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
+ }
+ }
+}
+
+func AssertEventsEqual(t *testing.T, gots, wants []*gomatrixserverlib.HeaderedEvent) {
+ t.Helper()
+ if len(gots) != len(wants) {
+ t.Fatalf("length mismatch: got %d events, want %d", len(gots), len(wants))
+ }
+ for i := range wants {
+ w := wants[i].JSON()
+ g := gots[i].JSON()
+ if !bytes.Equal(w, g) {
+ t.Errorf("event at index %d mismatch:\ngot %s\n\nwant %s", i, string(g), string(w))
+ }
+ }
+}