aboutsummaryrefslogtreecommitdiff
path: root/roomserver
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-06-13 16:28:41 +0200
committerGitHub <noreply@github.com>2023-06-13 16:28:41 +0200
commit7a2e325d1014d76188b47a011730a42443f3c174 (patch)
tree4b676ac3a083633a2dc76b8a02247093825182d2 /roomserver
parent2c87972a3a84be400e5c69e2e5a727f21b4e457e (diff)
Add `AssignRoomNID` to pre-assign roomNIDs (#3111)
Diffstat (limited to 'roomserver')
-rw-r--r--roomserver/storage/interface.go2
-rw-r--r--roomserver/storage/shared/storage.go20
-rw-r--r--roomserver/storage/shared/storage_test.go22
3 files changed, 44 insertions, 0 deletions
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index ef446378..7787d9f8 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -31,6 +31,7 @@ type Database interface {
UserRoomKeys
// Do we support processing input events for more than one room at a time?
SupportsConcurrentRoomInputs() bool
+ AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
// RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
@@ -212,6 +213,7 @@ type UserRoomKeys interface {
type RoomDatabase interface {
EventDatabase
UserRoomKeys
+ AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
// RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
RoomInfoByNID(ctx context.Context, roomNID types.RoomNID) (*types.RoomInfo, error)
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index d7ca3cef..bda51da8 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -662,6 +662,26 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e
return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID)
}
+func (d *Database) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) {
+ // This should already be checked, let's check it anyway.
+ _, err = gomatrixserverlib.GetRoomVersion(roomVersion)
+ if err != nil {
+ return 0, err
+ }
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ roomNID, err = d.assignRoomNID(ctx, txn, roomID.String(), roomVersion)
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ if err != nil {
+ return 0, err
+ }
+ // Not setting caches, as assignRoomNID already does this
+ return roomNID, err
+}
+
// GetOrCreateRoomInfo gets or creates a new RoomInfo, which is only safe to use with functions only needing a roomVersion or roomNID.
func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (roomInfo *types.RoomInfo, err error) {
// Get the default room version. If the client doesn't supply a room_version
diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go
index 4fa451bc..581d83ee 100644
--- a/roomserver/storage/shared/storage_test.go
+++ b/roomserver/storage/shared/storage_test.go
@@ -7,6 +7,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/caching"
+ "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
@@ -199,3 +200,24 @@ func TestUserRoomKeys(t *testing.T) {
assert.Error(t, err)
})
}
+
+func TestAssignRoomNID(t *testing.T) {
+ ctx := context.Background()
+ alice := test.NewUser(t)
+ room := test.NewRoom(t, alice)
+
+ roomID, err := spec.NewRoomID(room.ID)
+ assert.NoError(t, err)
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, close := mustCreateRoomserverDatabase(t, dbType)
+ defer close()
+
+ nid, err := db.AssignRoomNID(ctx, *roomID, room.Version)
+ assert.NoError(t, err)
+ assert.Greater(t, nid, types.EventNID(0))
+
+ _, err = db.AssignRoomNID(ctx, spec.RoomID{}, "notaroomversion")
+ assert.Error(t, err)
+ })
+}