aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkegsay <kegan@matrix.org>2022-04-08 10:12:30 +0100
committerGitHub <noreply@github.com>2022-04-08 10:12:30 +0100
commit7499147550110d24fa3a376bd811d9dd38971629 (patch)
tree335f11802f6cd391effddae9709b014ed1a17c58
parent955e6eb307c78594fe9614f6a304dc521ba28d49 (diff)
Add test infrastructure code for dendrite unit/integ tests (#2331)
* Add test infrastructure code for dendrite unit/integ tests Start re-enabling some syncapi storage tests in the process. * Linting * Add postgres service to unit tests * dendrite not syncv3 * Skip test which doesn't work * Linting * Add `jetstream.PrepareForTests` Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
-rw-r--r--.github/workflows/dendrite.yml25
-rw-r--r--roomserver/internal/input/input_test.go41
-rw-r--r--setup/jetstream/nats.go10
-rw-r--r--syncapi/storage/storage_test.go307
-rw-r--r--test/db.go127
-rw-r--r--test/event.go51
-rw-r--r--test/room.go223
-rw-r--r--test/user.go36
8 files changed, 598 insertions, 222 deletions
diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml
index c80a82d1..4f337a86 100644
--- a/.github/workflows/dendrite.yml
+++ b/.github/workflows/dendrite.yml
@@ -73,6 +73,26 @@ jobs:
timeout-minutes: 5
name: Unit tests (Go ${{ matrix.go }})
runs-on: ubuntu-latest
+ # Service containers to run with `container-job`
+ services:
+ # Label used to access the service container
+ postgres:
+ # Docker Hub image
+ image: postgres:13-alpine
+ # Provide the password for postgres
+ env:
+ POSTGRES_USER: postgres
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_DB: dendrite
+ ports:
+ # Maps tcp port 5432 on service container to the host
+ - 5432:5432
+ # Set health checks to wait until postgres has started
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
strategy:
fail-fast: false
matrix:
@@ -92,6 +112,11 @@ jobs:
restore-keys: |
${{ runner.os }}-go${{ matrix.go }}-test-
- run: go test ./...
+ env:
+ POSTGRES_HOST: localhost
+ POSTGRES_USER: postgres
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_DB: dendrite
# build Dendrite for linux with different architectures and go versions
build:
diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go
index 4fa96628..81c86ae3 100644
--- a/roomserver/internal/input/input_test.go
+++ b/roomserver/internal/input/input_test.go
@@ -2,7 +2,6 @@ package input_test
import (
"context"
- "fmt"
"os"
"testing"
"time"
@@ -12,30 +11,22 @@ import (
"github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
)
-func psqlConnectionString() config.DataSource {
- user := os.Getenv("POSTGRES_USER")
- if user == "" {
- user = "dendrite"
- }
- dbName := os.Getenv("POSTGRES_DB")
- if dbName == "" {
- dbName = "dendrite"
- }
- connStr := fmt.Sprintf(
- "user=%s dbname=%s sslmode=disable", user, dbName,
- )
- password := os.Getenv("POSTGRES_PASSWORD")
- if password != "" {
- connStr += fmt.Sprintf(" password=%s", password)
- }
- host := os.Getenv("POSTGRES_HOST")
- if host != "" {
- connStr += fmt.Sprintf(" host=%s", host)
- }
- return config.DataSource(connStr)
+var js nats.JetStreamContext
+var jc *nats.Conn
+
+func TestMain(m *testing.M) {
+ var pc *process.ProcessContext
+ pc, js, jc = jetstream.PrepareForTests()
+ code := m.Run()
+ pc.ShutdownDendrite()
+ pc.WaitForComponentsToFinish()
+ os.Exit(code)
}
func TestSingleTransactionOnInput(t *testing.T) {
@@ -63,7 +54,7 @@ func TestSingleTransactionOnInput(t *testing.T) {
}
db, err := storage.Open(
&config.DatabaseOptions{
- ConnectionString: psqlConnectionString(),
+ ConnectionString: "",
MaxOpenConnections: 1,
MaxIdleConnections: 1,
},
@@ -74,7 +65,9 @@ func TestSingleTransactionOnInput(t *testing.T) {
t.SkipNow()
}
inputter := &input.Inputer{
- DB: db,
+ DB: db,
+ JetStream: js,
+ NATSClient: jc,
}
res := &api.InputRoomEventsResponse{}
inputter.InputRoomEvents(
diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go
index 4e4fe7a2..1c8a89e8 100644
--- a/setup/jetstream/nats.go
+++ b/setup/jetstream/nats.go
@@ -13,12 +13,22 @@ import (
"github.com/sirupsen/logrus"
natsserver "github.com/nats-io/nats-server/v2/server"
+ "github.com/nats-io/nats.go"
natsclient "github.com/nats-io/nats.go"
)
var natsServer *natsserver.Server
var natsServerMutex sync.Mutex
+func PrepareForTests() (*process.ProcessContext, nats.JetStreamContext, *nats.Conn) {
+ cfg := &config.Dendrite{}
+ cfg.Defaults(true)
+ cfg.Global.JetStream.InMemory = true
+ pc := process.NewProcessContext()
+ js, jc := Prepare(pc, &cfg.Global.JetStream)
+ return pc, js, jc
+}
+
func Prepare(process *process.ProcessContext, cfg *config.JetStream) (natsclient.JetStreamContext, *natsclient.Conn) {
// check if we need an in-process NATS Server
if len(cfg.Addresses) != 0 {
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index 86432200..403b50ea 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -1,121 +1,28 @@
package storage_test
-// TODO: Fix these tests
-/*
import (
"context"
- "crypto/ed25519"
- "encoding/json"
"fmt"
- "os"
"testing"
- "time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage"
- "github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/types"
- userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
)
-var (
- ctx = context.Background()
- emptyStateKey = ""
- testOrigin = gomatrixserverlib.ServerName("hollow.knight")
- testRoomID = fmt.Sprintf("!hallownest:%s", testOrigin)
- testUserIDA = fmt.Sprintf("@hornet:%s", testOrigin)
- testUserIDB = fmt.Sprintf("@paleking:%s", testOrigin)
- testUserDeviceA = userapi.Device{
- UserID: testUserIDA,
- ID: "device_id_A",
- DisplayName: "Device A",
- }
- testRoomVersion = gomatrixserverlib.RoomVersionV4
- testKeyID = gomatrixserverlib.KeyID("ed25519:storage_test")
- testPrivateKey = ed25519.NewKeyFromSeed([]byte{
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
- 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
- })
-)
-
-func MustCreateEvent(t *testing.T, roomID string, prevs []*gomatrixserverlib.HeaderedEvent, b *gomatrixserverlib.EventBuilder) *gomatrixserverlib.HeaderedEvent {
- b.RoomID = roomID
- if prevs != nil {
- prevIDs := make([]string, len(prevs))
- for i := range prevs {
- prevIDs[i] = prevs[i].EventID()
- }
- b.PrevEvents = prevIDs
- }
- e, err := b.Build(time.Now(), testOrigin, testKeyID, testPrivateKey, testRoomVersion)
- if err != nil {
- t.Fatalf("failed to build event: %s", err)
- }
- return e.Headered(testRoomVersion)
-}
+var ctx = context.Background()
-func MustCreateDatabase(t *testing.T) storage.Database {
- dbname := fmt.Sprintf("test_%s.db", t.Name())
- if _, err := os.Stat(dbname); err == nil {
- if err = os.Remove(dbname); err != nil {
- t.Fatalf("tried to delete stale test database but failed: %s", err)
- }
- }
- db, err := sqlite3.NewDatabase(&config.DatabaseOptions{
- ConnectionString: config.DataSource(fmt.Sprintf("file:%s", dbname)),
+func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := storage.NewSyncServerDatasource(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
- return db
-}
-
-// Create a list of events which include a create event, join event and some messages.
-func SimpleRoom(t *testing.T, roomID, userA, userB string) (msgs []*gomatrixserverlib.HeaderedEvent, state []*gomatrixserverlib.HeaderedEvent) {
- var events []*gomatrixserverlib.HeaderedEvent
- events = append(events, MustCreateEvent(t, roomID, nil, &gomatrixserverlib.EventBuilder{
- Content: []byte(fmt.Sprintf(`{"room_version":"4","creator":"%s"}`, userA)),
- Type: "m.room.create",
- StateKey: &emptyStateKey,
- Sender: userA,
- Depth: int64(len(events) + 1),
- }))
- state = append(state, events[len(events)-1])
- events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
- Content: []byte(`{"membership":"join"}`),
- Type: "m.room.member",
- StateKey: &userA,
- Sender: userA,
- Depth: int64(len(events) + 1),
- }))
- state = append(state, events[len(events)-1])
- for i := 0; i < 10; i++ {
- events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
- Content: []byte(fmt.Sprintf(`{"body":"Message A %d"}`, i+1)),
- Type: "m.room.message",
- Sender: userA,
- Depth: int64(len(events) + 1),
- }))
- }
- events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
- Content: []byte(`{"membership":"join"}`),
- Type: "m.room.member",
- StateKey: &userB,
- Sender: userB,
- Depth: int64(len(events) + 1),
- }))
- state = append(state, events[len(events)-1])
- for i := 0; i < 10; i++ {
- events = append(events, MustCreateEvent(t, roomID, []*gomatrixserverlib.HeaderedEvent{events[len(events)-1]}, &gomatrixserverlib.EventBuilder{
- Content: []byte(fmt.Sprintf(`{"body":"Message B %d"}`, i+1)),
- Type: "m.room.message",
- Sender: userB,
- Depth: int64(len(events) + 1),
- }))
- }
-
- return events, state
+ return db, close
}
func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserverlib.HeaderedEvent) (positions []types.StreamPosition) {
@@ -138,111 +45,115 @@ func MustWriteEvents(t *testing.T, db storage.Database, events []*gomatrixserver
}
func TestWriteEvents(t *testing.T) {
- t.Parallel()
- db := MustCreateDatabase(t)
- events, _ := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
- MustWriteEvents(t, db, events)
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ t.Parallel()
+ alice := test.NewUser()
+ r := test.NewRoom(t, alice)
+ db, close := MustCreateDatabase(t, dbType)
+ defer close()
+ MustWriteEvents(t, db, r.Events())
+ })
}
-// These tests assert basic functionality of the IncrementalSync and CompleteSync functions.
-func TestSyncResponse(t *testing.T) {
- t.Parallel()
- db := MustCreateDatabase(t)
- events, state := SimpleRoom(t, testRoomID, testUserIDA, testUserIDB)
- positions := MustWriteEvents(t, db, events)
- latest, err := db.SyncPosition(ctx)
- if err != nil {
- t.Fatalf("failed to get SyncPosition: %s", err)
- }
+// These tests assert basic functionality of RecentEvents for PDUs
+func TestRecentEventsPDU(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := MustCreateDatabase(t, dbType)
+ defer close()
+ alice := test.NewUser()
+ var filter gomatrixserverlib.RoomEventFilter
+ filter.Limit = 100
+ r := test.NewRoom(t, alice)
+ r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": "hi"})
+ events := r.Events()
+ positions := MustWriteEvents(t, db, events)
+ latest, err := db.MaxStreamPositionForPDUs(ctx)
+ if err != nil {
+ t.Fatalf("failed to get MaxStreamPositionForPDUs: %s", err)
+ }
- testCases := []struct {
- Name string
- DoSync func() (*types.Response, error)
- WantTimeline []*gomatrixserverlib.HeaderedEvent
- WantState []*gomatrixserverlib.HeaderedEvent
- }{
- // The purpose of this test is to make sure that incremental syncs are including up to the latest events.
- // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
- // It makes sure the response includes the final event.
- {
- Name: "IncrementalSync penultimate",
- DoSync: func() (*types.Response, error) {
- from := types.StreamingToken{ // pretend we are at the penultimate event
- PDUPosition: positions[len(positions)-2],
- }
- res := types.NewResponse()
- return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
- },
- WantTimeline: events[len(events)-1:],
- },
- // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
- // number of returned events. This is critical for big rooms hence the test here.
- {
- Name: "IncrementalSync limited",
- DoSync: func() (*types.Response, error) {
- from := types.StreamingToken{ // pretend we are 10 events behind
- PDUPosition: positions[len(positions)-11],
- }
- res := types.NewResponse()
- // limit is set to 5
- return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
+ testCases := []struct {
+ Name string
+ From types.StreamPosition
+ To types.StreamPosition
+ WantEvents []*gomatrixserverlib.HeaderedEvent
+ WantLimited bool
+ }{
+ // The purpose of this test is to make sure that incremental syncs are including up to the latest events.
+ // It's a basic sanity test that sync works. It creates a `since` token that is on the penultimate event.
+ // It makes sure the response includes the final event.
+ {
+ Name: "IncrementalSync penultimate",
+ From: positions[len(positions)-2], // pretend we are at the penultimate event
+ To: latest,
+ WantEvents: events[len(events)-1:],
+ WantLimited: false,
},
- // want the last 5 events, NOT the last 10.
- WantTimeline: events[len(events)-5:],
- },
- // The purpose of this test is to check that CompleteSync returns all the current state as well as
- // honouring the `numRecentEventsPerRoom` value
- {
- Name: "CompleteSync limited",
- DoSync: func() (*types.Response, error) {
- res := types.NewResponse()
- // limit set to 5
- return db.CompleteSync(ctx, res, testUserDeviceA, 5)
- },
- // want the last 5 events
- WantTimeline: events[len(events)-5:],
- // want all state for the room
- WantState: state,
- },
- // The purpose of this test is to check that CompleteSync can return everything with a high enough
- // `numRecentEventsPerRoom`.
- {
- Name: "CompleteSync",
- DoSync: func() (*types.Response, error) {
- res := types.NewResponse()
- return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
- },
- WantTimeline: events,
- // We want no state at all as that field in /sync is the delta between the token (beginning of time)
- // and the START of the timeline.
- },
- }
+ /*
+ // The purpose of this test is to check that passing a `numRecentEventsPerRoom` correctly limits the
+ // number of returned events. This is critical for big rooms hence the test here.
+ {
+ Name: "IncrementalSync limited",
+ DoSync: func() (*types.Response, error) {
+ from := types.StreamingToken{ // pretend we are 10 events behind
+ PDUPosition: positions[len(positions)-11],
+ }
+ res := types.NewResponse()
+ // limit is set to 5
+ return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
+ },
+ // want the last 5 events, NOT the last 10.
+ WantTimeline: events[len(events)-5:],
+ },
+ // The purpose of this test is to check that CompleteSync returns all the current state as well as
+ // honouring the `numRecentEventsPerRoom` value
+ {
+ Name: "CompleteSync limited",
+ DoSync: func() (*types.Response, error) {
+ res := types.NewResponse()
+ // limit set to 5
+ return db.CompleteSync(ctx, res, testUserDeviceA, 5)
+ },
+ // want the last 5 events
+ WantTimeline: events[len(events)-5:],
+ // want all state for the room
+ WantState: state,
+ },
+ // The purpose of this test is to check that CompleteSync can return everything with a high enough
+ // `numRecentEventsPerRoom`.
+ {
+ Name: "CompleteSync",
+ DoSync: func() (*types.Response, error) {
+ res := types.NewResponse()
+ return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
+ },
+ WantTimeline: events,
+ // We want no state at all as that field in /sync is the delta between the token (beginning of time)
+ // and the START of the timeline.
+ }, */
+ }
- for _, tc := range testCases {
- t.Run(tc.Name, func(st *testing.T) {
- res, err := tc.DoSync()
- if err != nil {
- st.Fatalf("failed to do sync: %s", err)
- }
- next := types.StreamingToken{
- PDUPosition: latest.PDUPosition,
- TypingPosition: latest.TypingPosition,
- ReceiptPosition: latest.ReceiptPosition,
- SendToDevicePosition: latest.SendToDevicePosition,
- }
- if res.NextBatch.String() != next.String() {
- st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
- }
- roomRes, ok := res.Rooms.Join[testRoomID]
- if !ok {
- st.Fatalf("IncrementalSync response missing room %s - response: %+v", testRoomID, res)
- }
- assertEventsEqual(st, "state for "+testRoomID, false, roomRes.State.Events, tc.WantState)
- assertEventsEqual(st, "timeline for "+testRoomID, false, roomRes.Timeline.Events, tc.WantTimeline)
- })
- }
+ for _, tc := range testCases {
+ t.Run(tc.Name, func(st *testing.T) {
+ gotEvents, limited, err := db.RecentEvents(ctx, r.ID, types.Range{
+ From: tc.From,
+ To: tc.To,
+ }, &filter, true, true)
+ if err != nil {
+ st.Fatalf("failed to do sync: %s", err)
+ }
+ if limited != tc.WantLimited {
+ st.Errorf("got limited=%v want %v", limited, tc.WantLimited)
+ }
+ if len(gotEvents) != len(tc.WantEvents) {
+ st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents))
+ }
+ })
+ }
+ })
}
+/*
func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
t.Parallel()
db := MustCreateDatabase(t)
diff --git a/test/db.go b/test/db.go
new file mode 100644
index 00000000..9deec0a8
--- /dev/null
+++ b/test/db.go
@@ -0,0 +1,127 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "database/sql"
+ "fmt"
+ "os"
+ "os/exec"
+ "os/user"
+ "testing"
+)
+
+type DBType int
+
+var DBTypeSQLite DBType = 1
+var DBTypePostgres DBType = 2
+
+var Quiet = false
+
+func createLocalDB(dbName string) string {
+ if !Quiet {
+ fmt.Println("Note: tests require a postgres install accessible to the current user")
+ }
+ createDB := exec.Command("createdb", dbName)
+ if !Quiet {
+ createDB.Stdout = os.Stdout
+ createDB.Stderr = os.Stderr
+ }
+ err := createDB.Run()
+ if err != nil && !Quiet {
+ fmt.Println("createLocalDB returned error:", err)
+ }
+ return dbName
+}
+
+func currentUser() string {
+ user, err := user.Current()
+ if err != nil {
+ if !Quiet {
+ fmt.Println("cannot get current user: ", err)
+ }
+ os.Exit(2)
+ }
+ return user.Username
+}
+
+// Prepare a sqlite or postgres connection string for testing.
+// Returns the connection string to use and a close function which must be called when the test finishes.
+// Calling this function twice will return the same database, which will have data from previous tests
+// unless close() is called.
+// TODO: namespace for concurrent package tests
+func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) {
+ if dbType == DBTypeSQLite {
+ dbname := "dendrite_test.db"
+ return fmt.Sprintf("file:%s", dbname), func() {
+ err := os.Remove(dbname)
+ if err != nil {
+ t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err)
+ }
+ }
+ }
+
+ // Required vars: user and db
+ // We'll try to infer from the local env if they are missing
+ user := os.Getenv("POSTGRES_USER")
+ if user == "" {
+ user = currentUser()
+ }
+ dbName := os.Getenv("POSTGRES_DB")
+ if dbName == "" {
+ dbName = createLocalDB("dendrite_test")
+ }
+ connStr = fmt.Sprintf(
+ "user=%s dbname=%s sslmode=disable",
+ user, dbName,
+ )
+ // optional vars, used in CI
+ password := os.Getenv("POSTGRES_PASSWORD")
+ if password != "" {
+ connStr += fmt.Sprintf(" password=%s", password)
+ }
+ host := os.Getenv("POSTGRES_HOST")
+ if host != "" {
+ connStr += fmt.Sprintf(" host=%s", host)
+ }
+
+ return connStr, func() {
+ // Drop all tables on the database to get a fresh instance
+ db, err := sql.Open("postgres", connStr)
+ if err != nil {
+ t.Fatalf("failed to connect to postgres db '%s': %s", connStr, err)
+ }
+ _, err = db.Exec(`DROP SCHEMA public CASCADE;
+ CREATE SCHEMA public;`)
+ if err != nil {
+ t.Fatalf("failed to cleanup postgres db '%s': %s", connStr, err)
+ }
+ _ = db.Close()
+ }
+}
+
+// Creates subtests with each known DBType
+func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) {
+ dbs := map[string]DBType{
+ "postgres": DBTypePostgres,
+ "sqlite": DBTypeSQLite,
+ }
+ for dbName, dbType := range dbs {
+ dbt := dbType
+ t.Run(dbName, func(tt *testing.T) {
+ testFn(tt, dbt)
+ })
+ }
+}
diff --git a/test/event.go b/test/event.go
new file mode 100644
index 00000000..487b0936
--- /dev/null
+++ b/test/event.go
@@ -0,0 +1,51 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "crypto/ed25519"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type eventMods struct {
+ originServerTS time.Time
+ origin gomatrixserverlib.ServerName
+ stateKey *string
+ unsigned interface{}
+ keyID gomatrixserverlib.KeyID
+ privKey ed25519.PrivateKey
+}
+
+type eventModifier func(e *eventMods)
+
+func WithTimestamp(ts time.Time) eventModifier {
+ return func(e *eventMods) {
+ e.originServerTS = ts
+ }
+}
+
+func WithStateKey(skey string) eventModifier {
+ return func(e *eventMods) {
+ e.stateKey = &skey
+ }
+}
+
+func WithUnsigned(unsigned interface{}) eventModifier {
+ return func(e *eventMods) {
+ e.unsigned = unsigned
+ }
+}
diff --git a/test/room.go b/test/room.go
new file mode 100644
index 00000000..619cb5c9
--- /dev/null
+++ b/test/room.go
@@ -0,0 +1,223 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "crypto/ed25519"
+ "encoding/json"
+ "fmt"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type Preset int
+
+var (
+ PresetNone Preset = 0
+ PresetPrivateChat Preset = 1
+ PresetPublicChat Preset = 2
+ PresetTrustedPrivateChat Preset = 3
+
+ roomIDCounter = int64(0)
+
+ testKeyID = gomatrixserverlib.KeyID("ed25519:test")
+ testPrivateKey = ed25519.NewKeyFromSeed([]byte{
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
+ })
+)
+
+type Room struct {
+ ID string
+ Version gomatrixserverlib.RoomVersion
+ preset Preset
+ creator *User
+
+ authEvents gomatrixserverlib.AuthEvents
+ events []*gomatrixserverlib.HeaderedEvent
+}
+
+// Create a new test room. Automatically creates the initial create events.
+func NewRoom(t *testing.T, creator *User, modifiers ...roomModifier) *Room {
+ t.Helper()
+ counter := atomic.AddInt64(&roomIDCounter, 1)
+
+ // set defaults then let roomModifiers override
+ r := &Room{
+ ID: fmt.Sprintf("!%d:localhost", counter),
+ creator: creator,
+ authEvents: gomatrixserverlib.NewAuthEvents(nil),
+ preset: PresetPublicChat,
+ Version: gomatrixserverlib.RoomVersionV9,
+ }
+ for _, m := range modifiers {
+ m(t, r)
+ }
+ r.insertCreateEvents(t)
+ return r
+}
+
+func (r *Room) insertCreateEvents(t *testing.T) {
+ t.Helper()
+ var joinRule gomatrixserverlib.JoinRuleContent
+ var hisVis gomatrixserverlib.HistoryVisibilityContent
+ plContent := eventutil.InitialPowerLevelsContent(r.creator.ID)
+ switch r.preset {
+ case PresetTrustedPrivateChat:
+ fallthrough
+ case PresetPrivateChat:
+ joinRule.JoinRule = "invite"
+ hisVis.HistoryVisibility = "shared"
+ case PresetPublicChat:
+ joinRule.JoinRule = "public"
+ hisVis.HistoryVisibility = "shared"
+ }
+ r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomCreate, map[string]interface{}{
+ "creator": r.creator.ID,
+ "room_version": r.Version,
+ }, WithStateKey(""))
+ r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomMember, map[string]interface{}{
+ "membership": "join",
+ }, WithStateKey(r.creator.ID))
+ r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey(""))
+ r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey(""))
+ r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey(""))
+}
+
+// Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe.
+func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
+ t.Helper()
+ depth := 1 + len(r.events) // depth starts at 1
+
+ // possible event modifiers (optional fields)
+ mod := &eventMods{}
+ for _, m := range mods {
+ m(mod)
+ }
+
+ if mod.privKey == nil {
+ mod.privKey = testPrivateKey
+ }
+ if mod.keyID == "" {
+ mod.keyID = testKeyID
+ }
+ if mod.originServerTS.IsZero() {
+ mod.originServerTS = time.Now()
+ }
+ if mod.origin == "" {
+ mod.origin = gomatrixserverlib.ServerName("localhost")
+ }
+
+ var unsigned gomatrixserverlib.RawJSON
+ var err error
+ if mod.unsigned != nil {
+ unsigned, err = json.Marshal(mod.unsigned)
+ if err != nil {
+ t.Fatalf("CreateEvent[%s]: failed to marshal unsigned field: %s", eventType, err)
+ }
+ }
+
+ builder := &gomatrixserverlib.EventBuilder{
+ Sender: creator.ID,
+ RoomID: r.ID,
+ Type: eventType,
+ StateKey: mod.stateKey,
+ Depth: int64(depth),
+ Unsigned: unsigned,
+ }
+ err = builder.SetContent(content)
+ if err != nil {
+ t.Fatalf("CreateEvent[%s]: failed to SetContent: %s", eventType, err)
+ }
+ if depth > 1 {
+ builder.PrevEvents = []gomatrixserverlib.EventReference{r.events[len(r.events)-1].EventReference()}
+ }
+
+ eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
+ if err != nil {
+ t.Fatalf("CreateEvent[%s]: failed to StateNeededForEventBuilder: %s", eventType, err)
+ }
+ refs, err := eventsNeeded.AuthEventReferences(&r.authEvents)
+ if err != nil {
+ t.Fatalf("CreateEvent[%s]: failed to AuthEventReferences: %s", eventType, err)
+ }
+ builder.AuthEvents = refs
+ ev, err := builder.Build(
+ mod.originServerTS, mod.origin, mod.keyID,
+ mod.privKey, r.Version,
+ )
+ if err != nil {
+ t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err)
+ }
+ if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil {
+ t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err)
+ }
+ return ev.Headered(r.Version)
+}
+
+// Add a new event to this room DAG. Not thread-safe.
+func (r *Room) InsertEvent(t *testing.T, he *gomatrixserverlib.HeaderedEvent) {
+ t.Helper()
+ // Add the event to the list of auth events
+ r.events = append(r.events, he)
+ if he.StateKey() != nil {
+ err := r.authEvents.AddEvent(he.Unwrap())
+ if err != nil {
+ t.Fatalf("InsertEvent: failed to add event to auth events: %s", err)
+ }
+ }
+}
+
+func (r *Room) Events() []*gomatrixserverlib.HeaderedEvent {
+ return r.events
+}
+
+func (r *Room) CreateAndInsert(t *testing.T, creator *User, eventType string, content interface{}, mods ...eventModifier) *gomatrixserverlib.HeaderedEvent {
+ t.Helper()
+ he := r.CreateEvent(t, creator, eventType, content, mods...)
+ r.InsertEvent(t, he)
+ return he
+}
+
+// All room modifiers are below
+
+type roomModifier func(t *testing.T, r *Room)
+
+func RoomPreset(p Preset) roomModifier {
+ return func(t *testing.T, r *Room) {
+ switch p {
+ case PresetPrivateChat:
+ fallthrough
+ case PresetPublicChat:
+ fallthrough
+ case PresetTrustedPrivateChat:
+ fallthrough
+ case PresetNone:
+ r.preset = p
+ default:
+ t.Errorf("invalid RoomPreset: %v", p)
+ }
+ }
+}
+
+func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier {
+ return func(t *testing.T, r *Room) {
+ r.Version = ver
+ }
+}
diff --git a/test/user.go b/test/user.go
new file mode 100644
index 00000000..41a66e1c
--- /dev/null
+++ b/test/user.go
@@ -0,0 +1,36 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package test
+
+import (
+ "fmt"
+ "sync/atomic"
+)
+
+var (
+ userIDCounter = int64(0)
+)
+
+type User struct {
+ ID string
+}
+
+func NewUser() *User {
+ counter := atomic.AddInt64(&userIDCounter, 1)
+ u := &User{
+ ID: fmt.Sprintf("@%d:localhost", counter),
+ }
+ return u
+}