aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--roomserver/storage/postgres/events_table.go6
-rw-r--r--roomserver/storage/postgres/room_aliases_table.go6
-rw-r--r--roomserver/storage/postgres/rooms_table.go16
-rw-r--r--roomserver/storage/postgres/state_block_table.go53
-rw-r--r--roomserver/storage/postgres/state_block_table_test.go86
-rw-r--r--roomserver/storage/postgres/state_snapshot_table.go10
-rw-r--r--roomserver/storage/postgres/storage.go16
-rw-r--r--roomserver/storage/shared/storage.go2
-rw-r--r--roomserver/storage/sqlite3/events_table.go4
-rw-r--r--roomserver/storage/sqlite3/room_aliases_table.go6
-rw-r--r--roomserver/storage/sqlite3/rooms_table.go16
-rw-r--r--roomserver/storage/sqlite3/state_block_table.go49
-rw-r--r--roomserver/storage/sqlite3/state_block_table_test.go86
-rw-r--r--roomserver/storage/sqlite3/state_snapshot_table.go10
-rw-r--r--roomserver/storage/sqlite3/storage.go16
-rw-r--r--roomserver/storage/tables/interface.go2
-rw-r--r--roomserver/storage/tables/room_aliases_table_test.go96
-rw-r--r--roomserver/storage/tables/rooms_table_test.go128
-rw-r--r--roomserver/storage/tables/state_block_table_test.go92
-rw-r--r--roomserver/storage/tables/state_snapshot_table_test.go86
-rw-r--r--roomserver/types/types.go33
-rw-r--r--roomserver/types/types_test.go64
22 files changed, 570 insertions, 313 deletions
diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go
index 86d226ce..a4d05756 100644
--- a/roomserver/storage/postgres/events_table.go
+++ b/roomserver/storage/postgres/events_table.go
@@ -264,11 +264,11 @@ func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
- tuples := stateKeyTupleSorter(stateKeyTuples)
+ tuples := types.StateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
- eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
+ eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
- rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
+ rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), pq.Int64Array(eventTypeNIDArray), pq.Int64Array(eventStateKeyNIDArray))
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go
index d13df8e7..a84929f6 100644
--- a/roomserver/storage/postgres/room_aliases_table.go
+++ b/roomserver/storage/postgres/room_aliases_table.go
@@ -61,12 +61,12 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
-func createRoomAliasesTable(db *sql.DB) error {
+func CreateRoomAliasesTable(db *sql.DB) error {
_, err := db.Exec(roomAliasesSchema)
return err
}
-func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
+func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{}
return s, sqlutil.StatementList{
@@ -108,8 +108,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
var aliases []string
+ var alias string
for rows.Next() {
- var alias string
if err = rows.Scan(&alias); err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go
index b2685084..24362af7 100644
--- a/roomserver/storage/postgres/rooms_table.go
+++ b/roomserver/storage/postgres/rooms_table.go
@@ -95,12 +95,12 @@ type roomStatements struct {
bulkSelectRoomNIDsStmt *sql.Stmt
}
-func createRoomsTable(db *sql.DB) error {
+func CreateRoomsTable(db *sql.DB) error {
_, err := db.Exec(roomsSchema)
return err
}
-func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
+func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{}
return s, sqlutil.StatementList{
@@ -117,7 +117,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}
-func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
+func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
@@ -125,8 +125,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string
+ var roomID string
for rows.Next() {
- var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@@ -231,9 +231,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
+ var roomNID types.RoomNID
+ var roomVersion gomatrixserverlib.RoomVersion
for rows.Next() {
- var roomNID types.RoomNID
- var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
@@ -254,8 +254,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string
+ var roomID string
for rows.Next() {
- var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@@ -276,8 +276,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID
+ var roomNID types.RoomNID
for rows.Next() {
- var roomNID types.RoomNID
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go
index 6f8f9e1b..5af48f03 100644
--- a/roomserver/storage/postgres/state_block_table.go
+++ b/roomserver/storage/postgres/state_block_table.go
@@ -19,7 +19,6 @@ import (
"context"
"database/sql"
"fmt"
- "sort"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
@@ -71,12 +70,12 @@ type stateBlockStatements struct {
bulkSelectStateBlockEntriesStmt *sql.Stmt
}
-func createStateBlockTable(db *sql.DB) error {
+func CreateStateBlockTable(db *sql.DB) error {
_, err := db.Exec(stateDataSchema)
return err
}
-func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
+func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{}
return s, sqlutil.StatementList{
@@ -90,9 +89,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
- var nids types.EventNIDs
- for _, e := range entries {
- nids = append(nids, e.EventNID)
+ nids := make(types.EventNIDs, entries.Len())
+ for i := range entries {
+ nids[i] = entries[i].EventNID
}
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
@@ -113,15 +112,15 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0
+ var stateBlockNID types.StateBlockNID
+ var result pq.Int64Array
for ; rows.Next(); i++ {
- var stateBlockNID types.StateBlockNID
- var result pq.Int64Array
if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err
}
- r := []types.EventNID{}
- for _, e := range result {
- r = append(r, types.EventNID(e))
+ r := make([]types.EventNID, len(result))
+ for x := range result {
+ r[x] = types.EventNID(result[x])
}
results[i] = r
}
@@ -141,35 +140,3 @@ func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
}
return pq.Int64Array(nids)
}
-
-type stateKeyTupleSorter []types.StateKeyTuple
-
-func (s stateKeyTupleSorter) Len() int { return len(s) }
-func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
-func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
-
-// Check whether a tuple is in the list. Assumes that the list is sorted.
-func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
- i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
- return i < len(s) && s[i] == value
-}
-
-// List the unique eventTypeNIDs and eventStateKeyNIDs.
-// Assumes that the list is sorted.
-func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
- eventTypeNIDs = make(pq.Int64Array, len(s))
- eventStateKeyNIDs = make(pq.Int64Array, len(s))
- for i := range s {
- eventTypeNIDs[i] = int64(s[i].EventTypeNID)
- eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
- }
- eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
- eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
- return
-}
-
-type int64Sorter []int64
-
-func (s int64Sorter) Len() int { return len(s) }
-func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
-func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
diff --git a/roomserver/storage/postgres/state_block_table_test.go b/roomserver/storage/postgres/state_block_table_test.go
deleted file mode 100644
index a0e2ec95..00000000
--- a/roomserver/storage/postgres/state_block_table_test.go
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2017-2018 New Vector Ltd
-// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "sort"
- "testing"
-
- "github.com/matrix-org/dendrite/roomserver/types"
-)
-
-func TestStateKeyTupleSorter(t *testing.T) {
- input := stateKeyTupleSorter{
- {EventTypeNID: 1, EventStateKeyNID: 2},
- {EventTypeNID: 1, EventStateKeyNID: 4},
- {EventTypeNID: 2, EventStateKeyNID: 2},
- {EventTypeNID: 1, EventStateKeyNID: 1},
- }
- want := []types.StateKeyTuple{
- {EventTypeNID: 1, EventStateKeyNID: 1},
- {EventTypeNID: 1, EventStateKeyNID: 2},
- {EventTypeNID: 1, EventStateKeyNID: 4},
- {EventTypeNID: 2, EventStateKeyNID: 2},
- }
- doNotWant := []types.StateKeyTuple{
- {EventTypeNID: 0, EventStateKeyNID: 0},
- {EventTypeNID: 1, EventStateKeyNID: 3},
- {EventTypeNID: 2, EventStateKeyNID: 1},
- {EventTypeNID: 3, EventStateKeyNID: 1},
- }
- wantTypeNIDs := []int64{1, 2}
- wantStateKeyNIDs := []int64{1, 2, 4}
-
- // Sort the input and check it's in the right order.
- sort.Sort(input)
- gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
-
- for i := range want {
- if input[i] != want[i] {
- t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
- }
-
- if !input.contains(want[i]) {
- t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
- }
- }
-
- for i := range doNotWant {
- if input.contains(doNotWant[i]) {
- t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
- }
- }
-
- if len(wantTypeNIDs) != len(gotTypeNIDs) {
- t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
- }
-
- for i := range wantTypeNIDs {
- if wantTypeNIDs[i] != gotTypeNIDs[i] {
- t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
- }
- }
-
- if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
- t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
- }
-
- for i := range wantStateKeyNIDs {
- if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
- t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
- }
- }
-}
diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go
index 8ed88603..a24b7f3f 100644
--- a/roomserver/storage/postgres/state_snapshot_table.go
+++ b/roomserver/storage/postgres/state_snapshot_table.go
@@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
-func createStateSnapshotTable(db *sql.DB) error {
+func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema)
return err
}
-func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
+func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{}
return s, sqlutil.StatementList{
@@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
) (stateNID types.StateSnapshotNID, err error) {
nids = nids[:util.SortAndUnique(nids)]
- var id int64
- err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
+ err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
if err != nil {
return 0, err
}
- stateNID = types.StateSnapshotNID(id)
return
}
@@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer rows.Close() // nolint: errcheck
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
+ var stateBlockNIDs pq.Int64Array
for ; rows.Next(); i++ {
result := &results[i]
- var stateBlockNIDs pq.Int64Array
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 88df7200..70ea4d8b 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -80,19 +80,19 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventsTable(db); err != nil {
return err
}
- if err := createRoomsTable(db); err != nil {
+ if err := CreateRoomsTable(db); err != nil {
return err
}
- if err := createStateBlockTable(db); err != nil {
+ if err := CreateStateBlockTable(db); err != nil {
return err
}
- if err := createStateSnapshotTable(db); err != nil {
+ if err := CreateStateSnapshotTable(db); err != nil {
return err
}
if err := CreatePrevEventsTable(db); err != nil {
return err
}
- if err := createRoomAliasesTable(db); err != nil {
+ if err := CreateRoomAliasesTable(db); err != nil {
return err
}
if err := CreateInvitesTable(db); err != nil {
@@ -128,15 +128,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
- rooms, err := prepareRoomsTable(db)
+ rooms, err := PrepareRoomsTable(db)
if err != nil {
return err
}
- stateBlock, err := prepareStateBlockTable(db)
+ stateBlock, err := PrepareStateBlockTable(db)
if err != nil {
return err
}
- stateSnapshot, err := prepareStateSnapshotTable(db)
+ stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil {
return err
}
@@ -144,7 +144,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
- roomAliases, err := prepareRoomAliasesTable(db)
+ roomAliases, err := PrepareRoomAliasesTable(db)
if err != nil {
return err
}
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 252e94c7..cc4a9fff 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -1216,7 +1216,7 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
- return d.RoomsTable.SelectRoomIDs(ctx, nil)
+ return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
}
// ForgetRoom sets a users room to forgotten
diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go
index feb06150..1dda34c3 100644
--- a/roomserver/storage/sqlite3/events_table.go
+++ b/roomserver/storage/sqlite3/events_table.go
@@ -247,9 +247,9 @@ func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
- tuples := stateKeyTupleSorter(stateKeyTuples)
+ tuples := types.StateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
- eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
+ eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
params := make([]interface{}, 0, len(eventNIDs)+len(eventTypeNIDArray)+len(eventStateKeyNIDArray))
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
for _, v := range eventNIDs {
diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go
index 7c7bead9..3bdbbaa3 100644
--- a/roomserver/storage/sqlite3/room_aliases_table.go
+++ b/roomserver/storage/sqlite3/room_aliases_table.go
@@ -63,12 +63,12 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
-func createRoomAliasesTable(db *sql.DB) error {
+func CreateRoomAliasesTable(db *sql.DB) error {
_, err := db.Exec(roomAliasesSchema)
return err
}
-func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
+func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{
db: db,
}
@@ -113,8 +113,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
+ var alias string
for rows.Next() {
- var alias string
if err = rows.Scan(&alias); err != nil {
return
}
diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go
index cd60c678..03ad4b3d 100644
--- a/roomserver/storage/sqlite3/rooms_table.go
+++ b/roomserver/storage/sqlite3/rooms_table.go
@@ -86,12 +86,12 @@ type roomStatements struct {
selectRoomIDsStmt *sql.Stmt
}
-func createRoomsTable(db *sql.DB) error {
+func CreateRoomsTable(db *sql.DB) error {
_, err := db.Exec(roomsSchema)
return err
}
-func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
+func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{
db: db,
}
@@ -108,7 +108,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}
-func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
+func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
@@ -116,8 +116,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string
+ var roomID string
for rows.Next() {
- var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@@ -241,9 +241,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
+ var roomNID types.RoomNID
+ var roomVersion gomatrixserverlib.RoomVersion
for rows.Next() {
- var roomNID types.RoomNID
- var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
@@ -270,8 +270,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string
+ var roomID string
for rows.Next() {
- var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@@ -298,8 +298,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID
+ var roomNID types.RoomNID
for rows.Next() {
- var roomNID types.RoomNID
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}
diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go
index 3c829cdc..4e67d4da 100644
--- a/roomserver/storage/sqlite3/state_block_table.go
+++ b/roomserver/storage/sqlite3/state_block_table.go
@@ -20,7 +20,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
- "sort"
"strings"
"github.com/matrix-org/dendrite/internal"
@@ -64,12 +63,12 @@ type stateBlockStatements struct {
bulkSelectStateBlockEntriesStmt *sql.Stmt
}
-func createStateBlockTable(db *sql.DB) error {
+func CreateStateBlockTable(db *sql.DB) error {
_, err := db.Exec(stateDataSchema)
return err
}
-func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
+func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{
db: db,
}
@@ -85,9 +84,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
- nids := types.EventNIDs{} // zero slice to not store 'null' in the DB
- for _, e := range entries {
- nids = append(nids, e.EventNID)
+ nids := make(types.EventNIDs, entries.Len())
+ for i := range entries {
+ nids[i] = entries[i].EventNID
}
js, err := json.Marshal(nids)
if err != nil {
@@ -122,13 +121,13 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0
+ var stateBlockNID types.StateBlockNID
+ var result json.RawMessage
for ; rows.Next(); i++ {
- var stateBlockNID types.StateBlockNID
- var result json.RawMessage
if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err
}
- r := []types.EventNID{}
+ var r []types.EventNID
if err = json.Unmarshal(result, &r); err != nil {
return nil, fmt.Errorf("json.Unmarshal: %w", err)
}
@@ -142,35 +141,3 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
}
return results, err
}
-
-type stateKeyTupleSorter []types.StateKeyTuple
-
-func (s stateKeyTupleSorter) Len() int { return len(s) }
-func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
-func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
-
-// Check whether a tuple is in the list. Assumes that the list is sorted.
-func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
- i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
- return i < len(s) && s[i] == value
-}
-
-// List the unique eventTypeNIDs and eventStateKeyNIDs.
-// Assumes that the list is sorted.
-func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
- eventTypeNIDs = make([]int64, len(s))
- eventStateKeyNIDs = make([]int64, len(s))
- for i := range s {
- eventTypeNIDs[i] = int64(s[i].EventTypeNID)
- eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
- }
- eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
- eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
- return
-}
-
-type int64Sorter []int64
-
-func (s int64Sorter) Len() int { return len(s) }
-func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
-func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
diff --git a/roomserver/storage/sqlite3/state_block_table_test.go b/roomserver/storage/sqlite3/state_block_table_test.go
deleted file mode 100644
index 98439f5c..00000000
--- a/roomserver/storage/sqlite3/state_block_table_test.go
+++ /dev/null
@@ -1,86 +0,0 @@
-// Copyright 2017-2018 New Vector Ltd
-// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "sort"
- "testing"
-
- "github.com/matrix-org/dendrite/roomserver/types"
-)
-
-func TestStateKeyTupleSorter(t *testing.T) {
- input := stateKeyTupleSorter{
- {EventTypeNID: 1, EventStateKeyNID: 2},
- {EventTypeNID: 1, EventStateKeyNID: 4},
- {EventTypeNID: 2, EventStateKeyNID: 2},
- {EventTypeNID: 1, EventStateKeyNID: 1},
- }
- want := []types.StateKeyTuple{
- {EventTypeNID: 1, EventStateKeyNID: 1},
- {EventTypeNID: 1, EventStateKeyNID: 2},
- {EventTypeNID: 1, EventStateKeyNID: 4},
- {EventTypeNID: 2, EventStateKeyNID: 2},
- }
- doNotWant := []types.StateKeyTuple{
- {EventTypeNID: 0, EventStateKeyNID: 0},
- {EventTypeNID: 1, EventStateKeyNID: 3},
- {EventTypeNID: 2, EventStateKeyNID: 1},
- {EventTypeNID: 3, EventStateKeyNID: 1},
- }
- wantTypeNIDs := []int64{1, 2}
- wantStateKeyNIDs := []int64{1, 2, 4}
-
- // Sort the input and check it's in the right order.
- sort.Sort(input)
- gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
-
- for i := range want {
- if input[i] != want[i] {
- t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
- }
-
- if !input.contains(want[i]) {
- t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
- }
- }
-
- for i := range doNotWant {
- if input.contains(doNotWant[i]) {
- t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
- }
- }
-
- if len(wantTypeNIDs) != len(gotTypeNIDs) {
- t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
- }
-
- for i := range wantTypeNIDs {
- if wantTypeNIDs[i] != gotTypeNIDs[i] {
- t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
- }
- }
-
- if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
- t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
- }
-
- for i := range wantStateKeyNIDs {
- if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
- t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
- }
- }
-}
diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go
index 1f5e9ee3..b8136b75 100644
--- a/roomserver/storage/sqlite3/state_snapshot_table.go
+++ b/roomserver/storage/sqlite3/state_snapshot_table.go
@@ -68,12 +68,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
-func createStateSnapshotTable(db *sql.DB) error {
+func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema)
return err
}
-func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
+func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{
db: db,
}
@@ -96,12 +96,10 @@ func (s *stateSnapshotStatements) InsertState(
return
}
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
- var id int64
- err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&id)
+ err = insertStmt.QueryRowContext(ctx, stateBlockNIDs.Hash(), int64(roomNID), string(stateBlockNIDsJSON)).Scan(&stateNID)
if err != nil {
return 0, err
}
- stateNID = types.StateSnapshotNID(id)
return
}
@@ -127,9 +125,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
+ var stateBlockNIDsJSON string
for ; rows.Next(); i++ {
result := &results[i]
- var stateBlockNIDsJSON string
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
return nil, err
}
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index a4e32d52..8325fdad 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -89,19 +89,19 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventsTable(db); err != nil {
return err
}
- if err := createRoomsTable(db); err != nil {
+ if err := CreateRoomsTable(db); err != nil {
return err
}
- if err := createStateBlockTable(db); err != nil {
+ if err := CreateStateBlockTable(db); err != nil {
return err
}
- if err := createStateSnapshotTable(db); err != nil {
+ if err := CreateStateSnapshotTable(db); err != nil {
return err
}
if err := CreatePrevEventsTable(db); err != nil {
return err
}
- if err := createRoomAliasesTable(db); err != nil {
+ if err := CreateRoomAliasesTable(db); err != nil {
return err
}
if err := CreateInvitesTable(db); err != nil {
@@ -137,15 +137,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
- rooms, err := prepareRoomsTable(db)
+ rooms, err := PrepareRoomsTable(db)
if err != nil {
return err
}
- stateBlock, err := prepareStateBlockTable(db)
+ stateBlock, err := PrepareStateBlockTable(db)
if err != nil {
return err
}
- stateSnapshot, err := prepareStateSnapshotTable(db)
+ stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil {
return err
}
@@ -153,7 +153,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
- roomAliases, err := prepareRoomAliasesTable(db)
+ roomAliases, err := PrepareRoomAliasesTable(db)
if err != nil {
return err
}
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index 95609787..116e11c4 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -72,7 +72,7 @@ type Rooms interface {
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
- SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
+ SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
}
diff --git a/roomserver/storage/tables/room_aliases_table_test.go b/roomserver/storage/tables/room_aliases_table_test.go
new file mode 100644
index 00000000..8fb57d5a
--- /dev/null
+++ b/roomserver/storage/tables/room_aliases_table_test.go
@@ -0,0 +1,96 @@
+package tables_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/storage/postgres"
+ "github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/stretchr/testify/assert"
+)
+
+func mustCreateRoomAliasesTable(t *testing.T, dbType test.DBType) (tab tables.RoomAliases, close func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, sqlutil.NewExclusiveWriter())
+ assert.NoError(t, err)
+ switch dbType {
+ case test.DBTypePostgres:
+ err = postgres.CreateRoomAliasesTable(db)
+ assert.NoError(t, err)
+ tab, err = postgres.PrepareRoomAliasesTable(db)
+ case test.DBTypeSQLite:
+ err = sqlite3.CreateRoomAliasesTable(db)
+ assert.NoError(t, err)
+ tab, err = sqlite3.PrepareRoomAliasesTable(db)
+ }
+ assert.NoError(t, err)
+
+ return tab, close
+}
+
+func TestRoomAliasesTable(t *testing.T) {
+ alice := test.NewUser()
+ room := test.NewRoom(t, alice)
+ room2 := test.NewRoom(t, alice)
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, close := mustCreateRoomAliasesTable(t, dbType)
+ defer close()
+ alias, alias2, alias3 := "#alias:localhost", "#alias2:localhost", "#alias3:localhost"
+ // insert aliases
+ err := tab.InsertRoomAlias(ctx, nil, alias, room.ID, alice.ID)
+ assert.NoError(t, err)
+
+ err = tab.InsertRoomAlias(ctx, nil, alias2, room.ID, alice.ID)
+ assert.NoError(t, err)
+
+ err = tab.InsertRoomAlias(ctx, nil, alias3, room2.ID, alice.ID)
+ assert.NoError(t, err)
+
+ // verify we can get the roomID for the alias
+ roomID, err := tab.SelectRoomIDFromAlias(ctx, nil, alias)
+ assert.NoError(t, err)
+ assert.Equal(t, room.ID, roomID)
+
+ // .. and the creator
+ creator, err := tab.SelectCreatorIDFromAlias(ctx, nil, alias)
+ assert.NoError(t, err)
+ assert.Equal(t, alice.ID, creator)
+
+ creator, err = tab.SelectCreatorIDFromAlias(ctx, nil, "#doesntexist:localhost")
+ assert.NoError(t, err)
+ assert.Equal(t, "", creator)
+
+ roomID, err = tab.SelectRoomIDFromAlias(ctx, nil, "#doesntexist:localhost")
+ assert.NoError(t, err)
+ assert.Equal(t, "", roomID)
+
+ // get all aliases for a room
+ aliases, err := tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
+ assert.NoError(t, err)
+ assert.Equal(t, []string{alias, alias2}, aliases)
+
+ // delete an alias and verify it's deleted
+ err = tab.DeleteRoomAlias(ctx, nil, alias2)
+ assert.NoError(t, err)
+
+ aliases, err = tab.SelectAliasesFromRoomID(ctx, nil, room.ID)
+ assert.NoError(t, err)
+ assert.Equal(t, []string{alias}, aliases)
+
+ // deleting the same alias should be a no-op
+ err = tab.DeleteRoomAlias(ctx, nil, alias2)
+ assert.NoError(t, err)
+
+ // Delete non-existent alias should be a no-op
+ err = tab.DeleteRoomAlias(ctx, nil, "#doesntexist:localhost")
+ assert.NoError(t, err)
+ })
+}
diff --git a/roomserver/storage/tables/rooms_table_test.go b/roomserver/storage/tables/rooms_table_test.go
new file mode 100644
index 00000000..9872fb80
--- /dev/null
+++ b/roomserver/storage/tables/rooms_table_test.go
@@ -0,0 +1,128 @@
+package tables_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/storage/postgres"
+ "github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/util"
+ "github.com/stretchr/testify/assert"
+)
+
+func mustCreateRoomsTable(t *testing.T, dbType test.DBType) (tab tables.Rooms, close func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, sqlutil.NewExclusiveWriter())
+ assert.NoError(t, err)
+ switch dbType {
+ case test.DBTypePostgres:
+ err = postgres.CreateRoomsTable(db)
+ assert.NoError(t, err)
+ tab, err = postgres.PrepareRoomsTable(db)
+ case test.DBTypeSQLite:
+ err = sqlite3.CreateRoomsTable(db)
+ assert.NoError(t, err)
+ tab, err = sqlite3.PrepareRoomsTable(db)
+ }
+ assert.NoError(t, err)
+
+ return tab, close
+}
+
+func TestRoomsTable(t *testing.T) {
+ alice := test.NewUser()
+ room := test.NewRoom(t, alice)
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, close := mustCreateRoomsTable(t, dbType)
+ defer close()
+
+ wantRoomNID, err := tab.InsertRoomNID(ctx, nil, room.ID, room.Version)
+ assert.NoError(t, err)
+
+ // Create dummy room
+ _, err = tab.InsertRoomNID(ctx, nil, util.RandomString(16), room.Version)
+ assert.NoError(t, err)
+
+ gotRoomNID, err := tab.SelectRoomNID(ctx, nil, room.ID)
+ assert.NoError(t, err)
+ assert.Equal(t, wantRoomNID, gotRoomNID)
+
+ // Ensure non existent roomNID errors
+ roomNID, err := tab.SelectRoomNID(ctx, nil, "!doesnotexist:localhost")
+ assert.Error(t, err)
+ assert.Equal(t, types.RoomNID(0), roomNID)
+
+ roomInfo, err := tab.SelectRoomInfo(ctx, nil, room.ID)
+ assert.NoError(t, err)
+ assert.Equal(t, &types.RoomInfo{
+ RoomNID: wantRoomNID,
+ RoomVersion: room.Version,
+ StateSnapshotNID: 0,
+ IsStub: true, // there are no latestEventNIDs
+ }, roomInfo)
+
+ roomInfo, err = tab.SelectRoomInfo(ctx, nil, "!doesnotexist:localhost")
+ assert.NoError(t, err)
+ assert.Nil(t, roomInfo)
+
+ // There are no rooms with latestEventNIDs yet
+ roomIDs, err := tab.SelectRoomIDsWithEvents(ctx, nil)
+ assert.NoError(t, err)
+ assert.Equal(t, 0, len(roomIDs))
+
+ roomVersions, err := tab.SelectRoomVersionsForRoomNIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
+ assert.NoError(t, err)
+ assert.Equal(t, roomVersions[wantRoomNID], room.Version)
+ // Room does not exist
+ _, ok := roomVersions[1337]
+ assert.False(t, ok)
+
+ roomIDs, err = tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
+ assert.NoError(t, err)
+ assert.Equal(t, []string{room.ID}, roomIDs)
+
+ roomNIDs, err := tab.BulkSelectRoomNIDs(ctx, nil, []string{room.ID, "!doesnotexist:localhost"})
+ assert.NoError(t, err)
+ assert.Equal(t, []types.RoomNID{wantRoomNID}, roomNIDs)
+
+ wantEventNIDs := []types.EventNID{1, 2, 3}
+ lastEventSentNID := types.EventNID(3)
+ stateSnapshotNID := types.StateSnapshotNID(1)
+ // make the room "usable"
+ err = tab.UpdateLatestEventNIDs(ctx, nil, wantRoomNID, wantEventNIDs, lastEventSentNID, stateSnapshotNID)
+ assert.NoError(t, err)
+
+ roomInfo, err = tab.SelectRoomInfo(ctx, nil, room.ID)
+ assert.NoError(t, err)
+ assert.Equal(t, &types.RoomInfo{
+ RoomNID: wantRoomNID,
+ RoomVersion: room.Version,
+ StateSnapshotNID: 1,
+ IsStub: false,
+ }, roomInfo)
+
+ eventNIDs, snapshotNID, err := tab.SelectLatestEventNIDs(ctx, nil, wantRoomNID)
+ assert.NoError(t, err)
+ assert.Equal(t, wantEventNIDs, eventNIDs)
+ assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
+
+ // Again, doesn't exist
+ _, _, err = tab.SelectLatestEventNIDs(ctx, nil, 1337)
+ assert.Error(t, err)
+
+ eventNIDs, eventNID, snapshotNID, err := tab.SelectLatestEventsNIDsForUpdate(ctx, nil, wantRoomNID)
+ assert.NoError(t, err)
+ assert.Equal(t, wantEventNIDs, eventNIDs)
+ assert.Equal(t, types.EventNID(3), eventNID)
+ assert.Equal(t, types.StateSnapshotNID(1), snapshotNID)
+ })
+}
diff --git a/roomserver/storage/tables/state_block_table_test.go b/roomserver/storage/tables/state_block_table_test.go
new file mode 100644
index 00000000..de0b420b
--- /dev/null
+++ b/roomserver/storage/tables/state_block_table_test.go
@@ -0,0 +1,92 @@
+package tables_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/storage/postgres"
+ "github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/stretchr/testify/assert"
+)
+
+func mustCreateStateBlockTable(t *testing.T, dbType test.DBType) (tab tables.StateBlock, close func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, sqlutil.NewExclusiveWriter())
+ assert.NoError(t, err)
+ switch dbType {
+ case test.DBTypePostgres:
+ err = postgres.CreateStateBlockTable(db)
+ assert.NoError(t, err)
+ tab, err = postgres.PrepareStateBlockTable(db)
+ case test.DBTypeSQLite:
+ err = sqlite3.CreateStateBlockTable(db)
+ assert.NoError(t, err)
+ tab, err = sqlite3.PrepareStateBlockTable(db)
+ }
+ assert.NoError(t, err)
+
+ return tab, close
+}
+
+func TestStateBlockTable(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, close := mustCreateStateBlockTable(t, dbType)
+ defer close()
+
+ // generate some dummy data
+ var entries types.StateEntries
+ for i := 0; i < 100; i++ {
+ entry := types.StateEntry{
+ EventNID: types.EventNID(i),
+ }
+ entries = append(entries, entry)
+ }
+ stateBlockNID, err := tab.BulkInsertStateData(ctx, nil, entries)
+ assert.NoError(t, err)
+ assert.Equal(t, types.StateBlockNID(1), stateBlockNID)
+
+ // generate a different hash, to get a new StateBlockNID
+ var entries2 types.StateEntries
+ for i := 100; i < 300; i++ {
+ entry := types.StateEntry{
+ EventNID: types.EventNID(i),
+ }
+ entries2 = append(entries2, entry)
+ }
+ stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
+ assert.NoError(t, err)
+ assert.Equal(t, types.StateBlockNID(2), stateBlockNID)
+
+ eventNIDs, err := tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 2})
+ assert.NoError(t, err)
+ assert.Equal(t, len(entries), len(eventNIDs[0]))
+ assert.Equal(t, len(entries2), len(eventNIDs[1]))
+
+ // try to get a StateBlockNID which does not exist
+ _, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{5})
+ assert.Error(t, err)
+
+ // This should return an error, since we can only retrieve 1 StateBlock
+ _, err = tab.BulkSelectStateBlockEntries(ctx, nil, types.StateBlockNIDs{1, 5})
+ assert.Error(t, err)
+
+ for i := 0; i < 65555; i++ {
+ entry := types.StateEntry{
+ EventNID: types.EventNID(i),
+ }
+ entries2 = append(entries2, entry)
+ }
+ stateBlockNID, err = tab.BulkInsertStateData(ctx, nil, entries2)
+ assert.NoError(t, err)
+ assert.Equal(t, types.StateBlockNID(3), stateBlockNID)
+ })
+}
diff --git a/roomserver/storage/tables/state_snapshot_table_test.go b/roomserver/storage/tables/state_snapshot_table_test.go
new file mode 100644
index 00000000..dcdb5d8f
--- /dev/null
+++ b/roomserver/storage/tables/state_snapshot_table_test.go
@@ -0,0 +1,86 @@
+package tables_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/storage/postgres"
+ "github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/stretchr/testify/assert"
+)
+
+func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables.StateSnapshot, close func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, sqlutil.NewExclusiveWriter())
+ assert.NoError(t, err)
+ switch dbType {
+ case test.DBTypePostgres:
+ err = postgres.CreateStateSnapshotTable(db)
+ assert.NoError(t, err)
+ tab, err = postgres.PrepareStateSnapshotTable(db)
+ case test.DBTypeSQLite:
+ err = sqlite3.CreateStateSnapshotTable(db)
+ assert.NoError(t, err)
+ tab, err = sqlite3.PrepareStateSnapshotTable(db)
+ }
+ assert.NoError(t, err)
+
+ return tab, close
+}
+
+func TestStateSnapshotTable(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, close := mustCreateStateSnapshotTable(t, dbType)
+ defer close()
+
+ // generate some dummy data
+ var stateBlockNIDs types.StateBlockNIDs
+ for i := 0; i < 100; i++ {
+ stateBlockNIDs = append(stateBlockNIDs, types.StateBlockNID(i))
+ }
+ stateNID, err := tab.InsertState(ctx, nil, 1, stateBlockNIDs)
+ assert.NoError(t, err)
+ assert.Equal(t, types.StateSnapshotNID(1), stateNID)
+
+ // verify ON CONFLICT; Note: this updates the sequence!
+ stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs)
+ assert.NoError(t, err)
+ assert.Equal(t, types.StateSnapshotNID(1), stateNID)
+
+ // create a second snapshot
+ var stateBlockNIDs2 types.StateBlockNIDs
+ for i := 100; i < 150; i++ {
+ stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
+ }
+
+ stateNID, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
+ assert.NoError(t, err)
+ // StateSnapshotNID is now 3, since the DO UPDATE SET statement incremented the sequence
+ assert.Equal(t, types.StateSnapshotNID(3), stateNID)
+
+ nidLists, err := tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{1, 3})
+ assert.NoError(t, err)
+ assert.Equal(t, stateBlockNIDs, types.StateBlockNIDs(nidLists[0].StateBlockNIDs))
+ assert.Equal(t, stateBlockNIDs2, types.StateBlockNIDs(nidLists[1].StateBlockNIDs))
+
+ // check we get an error if the state snapshot does not exist
+ _, err = tab.BulkSelectStateBlockNIDs(ctx, nil, []types.StateSnapshotNID{2})
+ assert.Error(t, err)
+
+ // create a second snapshot
+ for i := 0; i < 65555; i++ {
+ stateBlockNIDs2 = append(stateBlockNIDs2, types.StateBlockNID(i))
+ }
+ _, err = tab.InsertState(ctx, nil, 1, stateBlockNIDs2)
+ assert.NoError(t, err)
+ })
+}
diff --git a/roomserver/types/types.go b/roomserver/types/types.go
index ce4e5fd1..62695aae 100644
--- a/roomserver/types/types.go
+++ b/roomserver/types/types.go
@@ -21,6 +21,7 @@ import (
"strings"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
"golang.org/x/crypto/blake2b"
)
@@ -97,6 +98,38 @@ func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
return a.EventStateKeyNID < b.EventStateKeyNID
}
+type StateKeyTupleSorter []StateKeyTuple
+
+func (s StateKeyTupleSorter) Len() int { return len(s) }
+func (s StateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
+func (s StateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
+// Check whether a tuple is in the list. Assumes that the list is sorted.
+func (s StateKeyTupleSorter) contains(value StateKeyTuple) bool {
+ i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
+ return i < len(s) && s[i] == value
+}
+
+// List the unique eventTypeNIDs and eventStateKeyNIDs.
+// Assumes that the list is sorted.
+func (s StateKeyTupleSorter) TypesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
+ eventTypeNIDs = make([]int64, len(s))
+ eventStateKeyNIDs = make([]int64, len(s))
+ for i := range s {
+ eventTypeNIDs[i] = int64(s[i].EventTypeNID)
+ eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
+ }
+ eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
+ eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
+ return
+}
+
+type int64Sorter []int64
+
+func (s int64Sorter) Len() int { return len(s) }
+func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
+func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
+
// A StateEntry is an entry in the room state of a matrix room.
type StateEntry struct {
StateKeyTuple
diff --git a/roomserver/types/types_test.go b/roomserver/types/types_test.go
index b1e84b82..a26b80f7 100644
--- a/roomserver/types/types_test.go
+++ b/roomserver/types/types_test.go
@@ -1,6 +1,7 @@
package types
import (
+ "sort"
"testing"
)
@@ -24,3 +25,66 @@ func TestDeduplicateStateEntries(t *testing.T) {
}
}
}
+
+func TestStateKeyTupleSorter(t *testing.T) {
+ input := StateKeyTupleSorter{
+ {EventTypeNID: 1, EventStateKeyNID: 2},
+ {EventTypeNID: 1, EventStateKeyNID: 4},
+ {EventTypeNID: 2, EventStateKeyNID: 2},
+ {EventTypeNID: 1, EventStateKeyNID: 1},
+ }
+ want := []StateKeyTuple{
+ {EventTypeNID: 1, EventStateKeyNID: 1},
+ {EventTypeNID: 1, EventStateKeyNID: 2},
+ {EventTypeNID: 1, EventStateKeyNID: 4},
+ {EventTypeNID: 2, EventStateKeyNID: 2},
+ }
+ doNotWant := []StateKeyTuple{
+ {EventTypeNID: 0, EventStateKeyNID: 0},
+ {EventTypeNID: 1, EventStateKeyNID: 3},
+ {EventTypeNID: 2, EventStateKeyNID: 1},
+ {EventTypeNID: 3, EventStateKeyNID: 1},
+ }
+ wantTypeNIDs := []int64{1, 2}
+ wantStateKeyNIDs := []int64{1, 2, 4}
+
+ // Sort the input and check it's in the right order.
+ sort.Sort(input)
+ gotTypeNIDs, gotStateKeyNIDs := input.TypesAndStateKeysAsArrays()
+
+ for i := range want {
+ if input[i] != want[i] {
+ t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
+ }
+
+ if !input.contains(want[i]) {
+ t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
+ }
+ }
+
+ for i := range doNotWant {
+ if input.contains(doNotWant[i]) {
+ t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
+ }
+ }
+
+ if len(wantTypeNIDs) != len(gotTypeNIDs) {
+ t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
+ }
+
+ for i := range wantTypeNIDs {
+ if wantTypeNIDs[i] != gotTypeNIDs[i] {
+ t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
+ }
+ }
+
+ if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
+ t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
+ }
+
+ for i := range wantStateKeyNIDs {
+ if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
+ t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
+ }
+ }
+}