aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage/tables/current_room_state_test.go
blob: 7d4ec812ce8830f9738572300f31d213a8677697 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package tables_test

import (
	"context"
	"database/sql"
	"fmt"
	"testing"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/syncapi/storage/postgres"
	"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
	"github.com/matrix-org/dendrite/syncapi/storage/tables"
	"github.com/matrix-org/dendrite/syncapi/synctypes"
	"github.com/matrix-org/dendrite/syncapi/types"
	"github.com/matrix-org/dendrite/test"
	"github.com/matrix-org/gomatrixserverlib/spec"
)

// newCurrentRoomStateTable prepares a database of the given type and returns a
// CurrentRoomState table backed by it, the underlying *sql.DB, and the cleanup
// function obtained from test.PrepareDBConnectionString.
//
// If any setup step fails, the opened *sql.DB (if any) is closed and the
// cleanup function is invoked before the test is aborted. Otherwise the
// prepared database would leak, because the cleanup func never reaches the
// caller.
func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) {
	t.Helper()
	connStr, closeDB := test.PrepareDBConnectionString(t, dbType)

	var db *sql.DB
	// fail releases everything acquired so far, then aborts the test.
	fail := func(format string, args ...any) {
		t.Helper()
		if db != nil {
			_ = db.Close() // best effort: the test is already failing
		}
		closeDB()
		t.Fatalf(format, args...)
	}

	db, err := sqlutil.Open(&config.DatabaseOptions{
		ConnectionString: config.DataSource(connStr),
	}, sqlutil.NewExclusiveWriter())
	if err != nil {
		fail("failed to open db: %s", err)
	}

	var tab tables.CurrentRoomState
	switch dbType {
	case test.DBTypePostgres:
		tab, err = postgres.NewPostgresCurrentRoomStateTable(db)
	case test.DBTypeSQLite:
		var stream sqlite3.StreamIDStatements
		if err = stream.Prepare(db); err != nil {
			fail("failed to prepare stream stmts: %s", err)
		}
		tab, err = sqlite3.NewSqliteCurrentRoomStateTable(db, &stream)
	default:
		// Without this, an unknown type would return a nil table and nil error.
		fail("unsupported database type: %v", dbType)
	}
	if err != nil {
		fail("failed to make new table: %s", err)
	}
	return tab, db, closeDB
}

// TestCurrentRoomStateTable writes the current state of a freshly created room
// into the current room state table on every supported database backend. It
// then checks that the events can be read back by event ID (with
// ExcludeFromSync unset) and via SelectCurrentState.
func TestCurrentRoomStateTable(t *testing.T) {
	ctx := context.Background()
	alice := test.NewUser(t)
	room := test.NewRoom(t, alice)
	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		tab, db, closeDB := newCurrentRoomStateTable(t, dbType)
		defer closeDB()
		events := room.CurrentState()

		// The by-ID lookup below requests the first wantCount events. Fail with
		// a clear message instead of an index-out-of-range panic if the room
		// fixture ever produces fewer.
		const wantCount = 4
		if len(events) < wantCount {
			t.Fatalf("expected at least %d current state events, got %d", wantCount, len(events))
		}

		err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
			for i, ev := range events {
				err := tab.UpsertRoomState(ctx, txn, ev, nil, types.StreamPosition(i))
				if err != nil {
					return fmt.Errorf("failed to UpsertRoomState: %w", err)
				}
			}
			wantEventIDs := make([]string, 0, wantCount)
			for _, ev := range events[:wantCount] {
				wantEventIDs = append(wantEventIDs, ev.EventID())
			}
			gotEvents, err := tab.SelectEventsWithEventIDs(ctx, txn, wantEventIDs)
			if err != nil {
				return fmt.Errorf("failed to SelectEventsWithEventIDs: %w", err)
			}
			if len(gotEvents) != len(wantEventIDs) {
				return fmt.Errorf("SelectEventsWithEventIDs\ngot %d, want %d results", len(gotEvents), len(wantEventIDs))
			}
			gotEventIDs := make(map[string]struct{}, len(gotEvents))
			for _, event := range gotEvents {
				// Events stored as current room state must never be excluded from sync.
				if event.ExcludeFromSync {
					return fmt.Errorf("SelectEventsWithEventIDs ExcludeFromSync should be false for current room state event %+v", event)
				}
				gotEventIDs[event.EventID()] = struct{}{}
			}
			for _, id := range wantEventIDs {
				if _, ok := gotEventIDs[id]; !ok {
					return fmt.Errorf("SelectEventsWithEventIDs\nexpected id %q not returned", id)
				}
			}

			testCurrentState(t, ctx, txn, tab, room)

			return nil
		})
		if err != nil {
			t.Fatalf("err: %v", err)
		}
	})
}

// testCurrentState runs a subtest that queries the room's current state via
// SelectCurrentState within txn. It uses a sequence of progressively modified
// state filters and checks the number of events returned after each step.
func testCurrentState(t *testing.T, ctx context.Context, txn *sql.Tx, tab tables.CurrentRoomState, room *test.Room) {
	t.Run("test currentState", func(t *testing.T) {
		// A single filter is shared by all steps; each step's tweak is applied
		// on top of the previous ones before querying.
		filter := synctypes.DefaultStateFilter()
		steps := []struct {
			tweak func()
			want  int
		}{
			// The default filter yields the complete state of the room.
			{tweak: nil, want: 5},
			// Lazy-loading members drops the membership event.
			{tweak: func() { filter.LazyLoadMembers = true }, want: 4},
			// Still lazy-loading, with NotTypes already excluding memberships.
			{tweak: func() {
				excluded := []string{spec.MRoomMember}
				filter.NotTypes = &excluded
			}, want: 4},
		}

		for _, step := range steps {
			if step.tweak != nil {
				step.tweak()
			}
			stateEvents, err := tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil)
			if err != nil {
				t.Fatal(err)
			}
			if n := len(stateEvents); n != step.want {
				t.Fatalf("expected %d state events, got %d", step.want, n)
			}
		}
	})
}