aboutsummaryrefslogtreecommitdiff
path: root/roomserver
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-04-27 08:07:13 +0200
committerGitHub <noreply@github.com>2023-04-27 08:07:13 +0200
commit2475cf4b61747e76a524af6f71a4eb7e112812af (patch)
treec2446b71a0538fc340a7fb23e8a6c75a48b0a7dd /roomserver
parentdd5e47a9a75f717381c27adebdee18aa80a1f256 (diff)
Add some roomserver UTs (#3067)
Adds tests for `QueryRestrictedJoinAllowed`, `IsServerAllowed` and `PerformRoomUpgrade`. Refactors the `QueryRoomVersionForRoom` method to accept a string and return a `gmsl.RoomVersion` instead of req/resp structs. Adds some more caching for `GetStateEvent` This should also fix #2912 by ignoring state events belonging to other users.
Diffstat (limited to 'roomserver')
-rw-r--r--roomserver/api/api.go4
-rw-r--r--roomserver/auth/auth.go4
-rw-r--r--roomserver/auth/auth_test.go85
-rw-r--r--roomserver/internal/input/input_events.go2
-rw-r--r--roomserver/internal/perform/perform_upgrade.go12
-rw-r--r--roomserver/internal/query/query.go46
-rw-r--r--roomserver/roomserver_test.go512
-rw-r--r--roomserver/storage/shared/storage.go26
8 files changed, 642 insertions, 49 deletions
diff --git a/roomserver/api/api.go b/roomserver/api/api.go
index 4ce40e3e..d4bd73ab 100644
--- a/roomserver/api/api.go
+++ b/roomserver/api/api.go
@@ -143,7 +143,7 @@ type ClientRoomserverAPI interface {
QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
- QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error
+ QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
QueryPublishedRooms(ctx context.Context, req *QueryPublishedRoomsRequest, res *QueryPublishedRoomsResponse) error
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
@@ -183,7 +183,7 @@ type FederationRoomserverAPI interface {
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
- QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error
+ QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID.
diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go
index 5f72454a..e872dcc3 100644
--- a/roomserver/auth/auth.go
+++ b/roomserver/auth/auth.go
@@ -26,6 +26,10 @@ func IsServerAllowed(
serverCurrentlyInRoom bool,
authEvents []*gomatrixserverlib.Event,
) bool {
+ // In practice should not happen, but avoids unneeded CPU cycles
+ if serverName == "" || len(authEvents) == 0 {
+ return false
+ }
historyVisibility := HistoryVisibilityForRoom(authEvents)
// 1. If the history_visibility was set to world_readable, allow.
diff --git a/roomserver/auth/auth_test.go b/roomserver/auth/auth_test.go
new file mode 100644
index 00000000..7478b924
--- /dev/null
+++ b/roomserver/auth/auth_test.go
@@ -0,0 +1,85 @@
+package auth
+
+import (
+ "testing"
+
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
+)
+
+func TestIsServerAllowed(t *testing.T) {
+ alice := test.NewUser(t)
+
+ tests := []struct {
+ name string
+ want bool
+ roomFunc func() *test.Room
+ serverName spec.ServerName
+ serverCurrentlyInRoom bool
+ }{
+ {
+ name: "no servername specified",
+ roomFunc: func() *test.Room { return test.NewRoom(t, alice) },
+ },
+ {
+ name: "no authEvents specified",
+ serverName: "test",
+ roomFunc: func() *test.Room { return &test.Room{} },
+ },
+ {
+ name: "default denied",
+ serverName: "test2",
+ roomFunc: func() *test.Room { return test.NewRoom(t, alice) },
+ },
+ {
+ name: "world readable room",
+ serverName: "test",
+ roomFunc: func() *test.Room {
+ return test.NewRoom(t, alice, test.RoomHistoryVisibility(gomatrixserverlib.HistoryVisibilityWorldReadable))
+ },
+ want: true,
+ },
+ {
+ name: "allowed due to alice being joined",
+ serverName: "test",
+ roomFunc: func() *test.Room { return test.NewRoom(t, alice) },
+ want: true,
+ },
+ {
+ name: "allowed due to 'serverCurrentlyInRoom'",
+ serverName: "test2",
+ roomFunc: func() *test.Room { return test.NewRoom(t, alice) },
+ want: true,
+ serverCurrentlyInRoom: true,
+ },
+ {
+ name: "allowed due to pending invite",
+ serverName: "test2",
+ roomFunc: func() *test.Room {
+ bob := test.User{ID: "@bob:test2"}
+ r := test.NewRoom(t, alice, test.RoomHistoryVisibility(gomatrixserverlib.HistoryVisibilityInvited))
+ r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
+ "membership": spec.Invite,
+ }, test.WithStateKey(bob.ID))
+ return r
+ },
+ want: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.roomFunc == nil {
+ t.Fatalf("missing roomFunc")
+ }
+ var authEvents []*gomatrixserverlib.Event
+ for _, ev := range tt.roomFunc().Events() {
+ authEvents = append(authEvents, ev.Event)
+ }
+
+ if got := IsServerAllowed(tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
+ t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 334e68b9..34566572 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -478,7 +478,7 @@ func (r *Inputer) processRoomEvent(
// If guest_access changed and is not can_join, kick all guest users.
if event.Type() == spec.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" {
- if err = r.kickGuests(ctx, event, roomInfo); err != nil {
+ if err = r.kickGuests(ctx, event, roomInfo); err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation")
}
}
diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go
index ed57abf2..644f7fda 100644
--- a/roomserver/internal/perform/perform_upgrade.go
+++ b/roomserver/internal/perform/perform_upgrade.go
@@ -319,9 +319,7 @@ func publishNewRoomAndUnpublishOldRoom(
}
func (r *Upgrader) validateRoomExists(ctx context.Context, roomID string) error {
- verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
- verRes := api.QueryRoomVersionForRoomResponse{}
- if err := r.URSAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
+ if _, err := r.URSAPI.QueryRoomVersionForRoom(ctx, roomID); err != nil {
return &api.PerformError{
Code: api.PerformErrorNoRoom,
Msg: "Room does not exist",
@@ -357,7 +355,7 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
continue
}
if event.Type() == spec.MRoomMember && !event.StateKeyEquals(userID) {
- // With the exception of bans and invites which we do want to copy, we
+ // With the exception of bans which we do want to copy, we
// should ignore membership events that aren't our own, as event auth will
// prevent us from being able to create membership events on behalf of other
// users anyway unless they are invites or bans.
@@ -367,11 +365,15 @@ func (r *Upgrader) generateInitialEvents(ctx context.Context, oldRoom *api.Query
}
switch membership {
case spec.Ban:
- case spec.Invite:
default:
continue
}
}
+ // skip events that rely on a specific user being present
+ sKey := *event.StateKey()
+ if event.Type() != spec.MRoomMember && len(sKey) > 0 && sKey[:1] == "@" {
+ continue
+ }
state[gomatrixserverlib.StateKeyTuple{EventType: event.Type(), StateKey: *event.StateKey()}] = event
}
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 8a5a9966..6c515dcc 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -521,14 +521,10 @@ func (r *Queryer) QueryMissingEvents(
response.Events = make([]*gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter))
for _, event := range loadedEvents {
if !eventsToFilter[event.EventID()] {
- roomVersion, verr := r.roomVersion(event.RoomID())
- if verr != nil {
- return verr
- }
if _, ok := redactEventIDs[event.EventID()]; ok {
event.Redact()
}
- response.Events = append(response.Events, event.Headered(roomVersion))
+ response.Events = append(response.Events, event.Headered(info.RoomVersion))
}
}
@@ -696,34 +692,20 @@ func GetAuthChain(
}
// QueryRoomVersionForRoom implements api.RoomserverInternalAPI
-func (r *Queryer) QueryRoomVersionForRoom(
- ctx context.Context,
- request *api.QueryRoomVersionForRoomRequest,
- response *api.QueryRoomVersionForRoomResponse,
-) error {
- if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok {
- response.RoomVersion = roomVersion
- return nil
+func (r *Queryer) QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) {
+ if roomVersion, ok := r.Cache.GetRoomVersion(roomID); ok {
+ return roomVersion, nil
}
- info, err := r.DB.RoomInfo(ctx, request.RoomID)
+ info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
- return err
+ return "", err
}
if info == nil {
- return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID)
+ return "", fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", roomID)
}
- response.RoomVersion = info.RoomVersion
- r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
- return nil
-}
-
-func (r *Queryer) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) {
- var res api.QueryRoomVersionForRoomResponse
- err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{
- RoomID: roomID,
- }, &res)
- return res.RoomVersion, err
+ r.Cache.StoreRoomVersion(roomID, info.RoomVersion)
+ return info.RoomVersion, nil
}
func (r *Queryer) QueryPublishedRooms(
@@ -910,8 +892,8 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
if err = json.Unmarshal(joinRulesEvent.Content(), &joinRules); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
- // If the join rule isn't "restricted" then there's nothing more to do.
- res.Restricted = joinRules.JoinRule == spec.Restricted
+ // If the join rule isn't "restricted" or "knock_restricted" then there's nothing more to do.
+ res.Restricted = joinRules.JoinRule == spec.Restricted || joinRules.JoinRule == spec.KnockRestricted
if !res.Restricted {
return nil
}
@@ -932,9 +914,9 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
if err != nil {
return fmt.Errorf("r.DB.GetStateEvent: %w", err)
}
- var powerLevels gomatrixserverlib.PowerLevelContent
- if err = json.Unmarshal(powerLevelsEvent.Content(), &powerLevels); err != nil {
- return fmt.Errorf("json.Unmarshal: %w", err)
+ powerLevels, err := powerLevelsEvent.PowerLevels()
+ if err != nil {
+ return fmt.Errorf("unable to get powerlevels: %w", err)
}
// Step through the join rules and see if the user matches any of them.
for _, rule := range joinRules.Allow {
diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go
index 67d6db46..d1a74d3c 100644
--- a/roomserver/roomserver_test.go
+++ b/roomserver/roomserver_test.go
@@ -8,10 +8,13 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/caching"
+ "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
+ "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/types"
@@ -582,3 +585,512 @@ func TestRedaction(t *testing.T) {
}
})
}
+
+func TestQueryRestrictedJoinAllowed(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+
+ // a room we don't create in the database
+ allowedByRoomNotExists := test.NewRoom(t, alice)
+
+ // a room we create in the database, used for authorisation
+ allowedByRoomExists := test.NewRoom(t, alice)
+ allowedByRoomExists.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{
+ "membership": spec.Join,
+ }, test.WithStateKey(bob.ID))
+
+ testCases := []struct {
+ name string
+ prepareRoomFunc func(t *testing.T) *test.Room
+ wantResponse api.QueryRestrictedJoinAllowedResponse
+ }{
+ {
+ name: "public room unrestricted",
+ prepareRoomFunc: func(t *testing.T) *test.Room {
+ return test.NewRoom(t, alice)
+ },
+ wantResponse: api.QueryRestrictedJoinAllowedResponse{
+ Resident: true,
+ },
+ },
+ {
+ name: "room version without restrictions",
+ prepareRoomFunc: func(t *testing.T) *test.Room {
+ return test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV7))
+ },
+ },
+ {
+ name: "restricted only", // bob is not allowed to join
+ prepareRoomFunc: func(t *testing.T) *test.Room {
+ r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV8))
+ r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
+ "join_rule": spec.Restricted,
+ }, test.WithStateKey(""))
+ return r
+ },
+ wantResponse: api.QueryRestrictedJoinAllowedResponse{
+ Resident: true,
+ Restricted: true,
+ },
+ },
+ {
+ name: "knock_restricted",
+ prepareRoomFunc: func(t *testing.T) *test.Room {
+ r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV8))
+ r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
+ "join_rule": spec.KnockRestricted,
+ }, test.WithStateKey(""))
+ return r
+ },
+ wantResponse: api.QueryRestrictedJoinAllowedResponse{
+ Resident: true,
+ Restricted: true,
+ },
+ },
+ {
+ name: "restricted with pending invite", // bob should be allowed to join
+ prepareRoomFunc: func(t *testing.T) *test.Room {
+ r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV8))
+ r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
+ "join_rule": spec.Restricted,
+ }, test.WithStateKey(""))
+ r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
+ "membership": spec.Invite,
+ }, test.WithStateKey(bob.ID))
+ return r
+ },
+ wantResponse: api.QueryRestrictedJoinAllowedResponse{
+ Resident: true,
+ Restricted: true,
+ Allowed: true,
+ },
+ },
+ {
+ name: "restricted with allowed room_id, but missing room", // bob should not be allowed to join, as we don't know about the room
+ prepareRoomFunc: func(t *testing.T) *test.Room {
+ r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV10))
+ r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
+ "join_rule": spec.KnockRestricted,
+ "allow": []map[string]interface{}{
+ {
+ "room_id": allowedByRoomNotExists.ID,
+ "type": spec.MRoomMembership,
+ },
+ },
+ }, test.WithStateKey(""))
+ r.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{
+ "membership": spec.Join,
+ "join_authorised_via_users_server": alice.ID,
+ }, test.WithStateKey(bob.ID))
+ return r
+ },
+ wantResponse: api.QueryRestrictedJoinAllowedResponse{
+ Restricted: true,
+ },
+ },
+ {
+ name: "restricted with allowed room_id", // bob should be allowed to join, as we know about the room
+ prepareRoomFunc: func(t *testing.T) *test.Room {
+ r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV10))
+ r.CreateAndInsert(t, alice, spec.MRoomJoinRules, map[string]interface{}{
+ "join_rule": spec.KnockRestricted,
+ "allow": []map[string]interface{}{
+ {
+ "room_id": allowedByRoomExists.ID,
+ "type": spec.MRoomMembership,
+ },
+ },
+ }, test.WithStateKey(""))
+ r.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{
+ "membership": spec.Join,
+ "join_authorised_via_users_server": alice.ID,
+ }, test.WithStateKey(bob.ID))
+ return r
+ },
+ wantResponse: api.QueryRestrictedJoinAllowedResponse{
+ Resident: true,
+ Restricted: true,
+ Allowed: true,
+ AuthorisedVia: alice.ID,
+ },
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ natsInstance := jetstream.NATSInstance{}
+ defer close()
+
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ rsAPI.SetFederationAPI(nil, nil)
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if tc.prepareRoomFunc == nil {
+ t.Fatal("missing prepareRoomFunc")
+ }
+ testRoom := tc.prepareRoomFunc(t)
+ // Create the room
+ if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, testRoom.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+
+ if err := api.SendEvents(processCtx.Context(), rsAPI, api.KindNew, allowedByRoomExists.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+
+ req := api.QueryRestrictedJoinAllowedRequest{
+ UserID: bob.ID,
+ RoomID: testRoom.ID,
+ }
+ res := api.QueryRestrictedJoinAllowedResponse{}
+ if err := rsAPI.QueryRestrictedJoinAllowed(processCtx.Context(), &req, &res); err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(tc.wantResponse, res) {
+ t.Fatalf("unexpected response, want %#v - got %#v", tc.wantResponse, res)
+ }
+ })
+ }
+ })
+}
+
+func TestUpgrade(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ charlie := test.NewUser(t)
+ ctx := context.Background()
+
+ spaceChild := test.NewRoom(t, alice)
+ validateTuples := []gomatrixserverlib.StateKeyTuple{
+ {EventType: spec.MRoomCreate},
+ {EventType: spec.MRoomPowerLevels},
+ {EventType: spec.MRoomJoinRules},
+ {EventType: spec.MRoomName},
+ {EventType: spec.MRoomCanonicalAlias},
+ {EventType: "m.room.tombstone"},
+ {EventType: "m.custom.event"},
+ {EventType: "m.space.child", StateKey: spaceChild.ID},
+ {EventType: "m.custom.event", StateKey: alice.ID},
+ {EventType: spec.MRoomMember, StateKey: charlie.ID}, // ban should be transferred
+ }
+
+ validate := func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) {
+
+ oldRoomState := &api.QueryCurrentStateResponse{}
+ if err := rsAPI.QueryCurrentState(ctx, &api.QueryCurrentStateRequest{
+ RoomID: oldRoomID,
+ StateTuples: validateTuples,
+ }, oldRoomState); err != nil {
+ t.Fatal(err)
+ }
+
+ newRoomState := &api.QueryCurrentStateResponse{}
+ if err := rsAPI.QueryCurrentState(ctx, &api.QueryCurrentStateRequest{
+ RoomID: newRoomID,
+ StateTuples: validateTuples,
+ }, newRoomState); err != nil {
+ t.Fatal(err)
+ }
+
+ // the old room should have a tombstone event
+ ev := oldRoomState.StateEvents[gomatrixserverlib.StateKeyTuple{EventType: "m.room.tombstone"}]
+ replacementRoom := gjson.GetBytes(ev.Content(), "replacement_room").Str
+ if replacementRoom != newRoomID {
+ t.Fatalf("tombstone event has replacement_room '%s', expected '%s'", replacementRoom, newRoomID)
+ }
+
+ // the new room should have a predecessor equal to the old room
+ ev = newRoomState.StateEvents[gomatrixserverlib.StateKeyTuple{EventType: spec.MRoomCreate}]
+ predecessor := gjson.GetBytes(ev.Content(), "predecessor.room_id").Str
+ if predecessor != oldRoomID {
+ t.Fatalf("got predecessor room '%s', expected '%s'", predecessor, oldRoomID)
+ }
+
+ for _, tuple := range validateTuples {
+ // Skip create and powerlevel event (new room has e.g. predecessor event, old room has restricted powerlevels)
+ switch tuple.EventType {
+ case spec.MRoomCreate, spec.MRoomPowerLevels, spec.MRoomCanonicalAlias:
+ continue
+ }
+ oldEv, ok := oldRoomState.StateEvents[tuple]
+ if !ok {
+ t.Logf("skipping tuple %#v as it doesn't exist in the old room", tuple)
+ continue
+ }
+ newEv, ok := newRoomState.StateEvents[tuple]
+ if !ok {
+ t.Logf("skipping tuple %#v as it doesn't exist in the new room", tuple)
+ continue
+ }
+
+ if !reflect.DeepEqual(oldEv.Content(), newEv.Content()) {
+ t.Logf("OldEvent QueryCurrentState: %s", string(oldEv.Content()))
+ t.Logf("NewEvent QueryCurrentState: %s", string(newEv.Content()))
+ t.Errorf("event content mismatch")
+ }
+ }
+ }
+
+ testCases := []struct {
+ name string
+ upgradeUser string
+ roomFunc func(rsAPI api.RoomserverInternalAPI) string
+ validateFunc func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI)
+ wantNewRoom bool
+ }{
+ {
+ name: "invalid userID",
+ upgradeUser: "!notvalid:test",
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ room := test.NewRoom(t, alice)
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+ return room.ID
+ },
+ },
+ {
+ name: "invalid roomID",
+ upgradeUser: alice.ID,
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ return "!doesnotexist:test"
+ },
+ },
+ {
+ name: "powerlevel too low",
+ upgradeUser: bob.ID,
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ room := test.NewRoom(t, alice)
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+ return room.ID
+ },
+ },
+ {
+ name: "successful upgrade on new room",
+ upgradeUser: alice.ID,
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ room := test.NewRoom(t, alice)
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+ return room.ID
+ },
+ wantNewRoom: true,
+ validateFunc: validate,
+ },
+ {
+ name: "successful upgrade on new room with other state events",
+ upgradeUser: alice.ID,
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ r := test.NewRoom(t, alice)
+ r.CreateAndInsert(t, alice, spec.MRoomName, map[string]interface{}{
+ "name": "my new name",
+ }, test.WithStateKey(""))
+ r.CreateAndInsert(t, alice, spec.MRoomCanonicalAlias, eventutil.CanonicalAliasContent{
+ Alias: "#myalias:test",
+ }, test.WithStateKey(""))
+
+ // this will be transferred
+ r.CreateAndInsert(t, alice, "m.custom.event", map[string]interface{}{
+ "random": "i should exist",
+ }, test.WithStateKey(""))
+
+ // the following will be ignored
+ r.CreateAndInsert(t, alice, "m.custom.event", map[string]interface{}{
+ "random": "i will be ignored",
+ }, test.WithStateKey(alice.ID))
+
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+ return r.ID
+ },
+ wantNewRoom: true,
+ validateFunc: validate,
+ },
+ {
+ name: "with published room",
+ upgradeUser: alice.ID,
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ r := test.NewRoom(t, alice)
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+
+ if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{
+ RoomID: r.ID,
+ Visibility: spec.Public,
+ }, &api.PerformPublishResponse{}); err != nil {
+ t.Fatal(err)
+ }
+
+ return r.ID
+ },
+ wantNewRoom: true,
+ validateFunc: func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) {
+ validate(t, oldRoomID, newRoomID, rsAPI)
+ // check that the new room is published
+ res := &api.QueryPublishedRoomsResponse{}
+ if err := rsAPI.QueryPublishedRooms(ctx, &api.QueryPublishedRoomsRequest{RoomID: newRoomID}, res); err != nil {
+ t.Fatal(err)
+ }
+ if len(res.RoomIDs) == 0 {
+ t.Fatalf("expected room to be published, but wasn't: %#v", res.RoomIDs)
+ }
+ },
+ },
+ {
+ name: "with alias",
+ upgradeUser: alice.ID,
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ r := test.NewRoom(t, alice)
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+
+ if err := rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{
+ RoomID: r.ID,
+ Alias: "#myroomalias:test",
+ }, &api.SetRoomAliasResponse{}); err != nil {
+ t.Fatal(err)
+ }
+
+ return r.ID
+ },
+ wantNewRoom: true,
+ validateFunc: func(t *testing.T, oldRoomID, newRoomID string, rsAPI api.RoomserverInternalAPI) {
+ validate(t, oldRoomID, newRoomID, rsAPI)
+ // check that the old room has no aliases
+ res := &api.GetAliasesForRoomIDResponse{}
+ if err := rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: oldRoomID}, res); err != nil {
+ t.Fatal(err)
+ }
+ if len(res.Aliases) != 0 {
+ t.Fatalf("expected old room aliases to be empty, but wasn't: %#v", res.Aliases)
+ }
+
+ // check that the new room has aliases
+ if err := rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: newRoomID}, res); err != nil {
+ t.Fatal(err)
+ }
+ if len(res.Aliases) == 0 {
+ t.Fatalf("expected room aliases to be transferred, but wasn't: %#v", res.Aliases)
+ }
+ },
+ },
+ {
+ name: "bans are transferred",
+ upgradeUser: alice.ID,
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ r := test.NewRoom(t, alice)
+ r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{
+ "membership": spec.Ban,
+ }, test.WithStateKey(charlie.ID))
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+ return r.ID
+ },
+ wantNewRoom: true,
+ validateFunc: validate,
+ },
+ {
+ name: "space childs are transferred",
+ upgradeUser: alice.ID,
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ r := test.NewRoom(t, alice)
+
+ r.CreateAndInsert(t, alice, "m.space.child", map[string]interface{}{}, test.WithStateKey(spaceChild.ID))
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+ return r.ID
+ },
+ wantNewRoom: true,
+ validateFunc: validate,
+ },
+ {
+ name: "custom state is not taken to the new room", // https://github.com/matrix-org/dendrite/issues/2912
+ upgradeUser: charlie.ID,
+ roomFunc: func(rsAPI api.RoomserverInternalAPI) string {
+ r := test.NewRoom(t, alice, test.RoomVersion(gomatrixserverlib.RoomVersionV6))
+ // Bob and Charlie join
+ r.CreateAndInsert(t, bob, spec.MRoomMember, map[string]interface{}{"membership": spec.Join}, test.WithStateKey(bob.ID))
+ r.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]interface{}{"membership": spec.Join}, test.WithStateKey(charlie.ID))
+
+ // make Charlie an admin so the room can be upgraded
+ r.CreateAndInsert(t, alice, spec.MRoomPowerLevels, gomatrixserverlib.PowerLevelContent{
+ Users: map[string]int64{
+ charlie.ID: 100,
+ },
+ }, test.WithStateKey(""))
+
+ // Alice creates a custom event
+ r.CreateAndInsert(t, alice, "m.custom.event", map[string]interface{}{
+ "random": "data",
+ }, test.WithStateKey(alice.ID))
+ r.CreateAndInsert(t, alice, spec.MRoomMember, map[string]interface{}{"membership": spec.Leave}, test.WithStateKey(alice.ID))
+
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, r.Events(), "test", "test", "test", nil, false); err != nil {
+ t.Errorf("failed to send events: %v", err)
+ }
+ return r.ID
+ },
+ wantNewRoom: true,
+ validateFunc: validate,
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ natsInstance := jetstream.NATSInstance{}
+ defer close()
+
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
+ rsAPI.SetFederationAPI(nil, nil)
+ rsAPI.SetUserAPI(userAPI)
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if tc.roomFunc == nil {
+ t.Fatalf("missing roomFunc")
+ }
+ if tc.upgradeUser == "" {
+ tc.upgradeUser = alice.ID
+ }
+ roomID := tc.roomFunc(rsAPI)
+
+ upgradeReq := api.PerformRoomUpgradeRequest{
+ RoomID: roomID,
+ UserID: tc.upgradeUser,
+ RoomVersion: version.DefaultRoomVersion(), // always upgrade to the latest version
+ }
+ upgradeRes := api.PerformRoomUpgradeResponse{}
+
+ if err := rsAPI.PerformRoomUpgrade(processCtx.Context(), &upgradeReq, &upgradeRes); err != nil {
+ t.Fatal(err)
+ }
+
+ if tc.wantNewRoom && upgradeRes.NewRoomID == "" {
+ t.Fatalf("expected a new room, but the upgrade failed")
+ }
+ if !tc.wantNewRoom && upgradeRes.NewRoomID != "" {
+ t.Fatalf("expected no new room, but the upgrade succeeded")
+ }
+ if tc.validateFunc != nil {
+ tc.validateFunc(t, roomID, upgradeRes.NewRoomID, rsAPI)
+ }
+ })
+ }
+ })
+}
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 8db11644..b411a4cd 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -669,13 +669,17 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event *gomatrixserve
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return nil, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
}
- if roomVersion == "" {
- rv, ok := d.Cache.GetRoomVersion(event.RoomID())
- if ok {
- roomVersion = rv
- }
+
+ roomNID, nidOK := d.Cache.GetRoomServerRoomNID(event.RoomID())
+ cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(event.RoomID())
+ // if we found both, the roomNID and version in our cache, no need to query the database
+ if nidOK && versionOK {
+ return &types.RoomInfo{
+ RoomNID: roomNID,
+ RoomVersion: cachedRoomVersion,
+ }, nil
}
- var roomNID types.RoomNID
+
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion)
if err != nil {
@@ -1164,7 +1168,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
if roomInfo.IsStub() {
return nil, nil
}
- eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
+ eventTypeNID, err := d.GetOrCreateEventTypeNID(ctx, evType)
if err == sql.ErrNoRows {
// No rooms have an event of this type, otherwise we'd have an event type NID
return nil, nil
@@ -1172,7 +1176,7 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
if err != nil {
return nil, err
}
- stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey)
+ stateKeyNID, err := d.GetOrCreateEventStateKeyNID(ctx, &stateKey)
if err == sql.ErrNoRows {
// No rooms have a state event with this state key, otherwise we'd have an state key NID
return nil, nil
@@ -1201,6 +1205,10 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
// return the event requested
for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
+ cachedEvent, ok := d.Cache.GetRoomServerEvent(e.EventNID)
+ if ok {
+ return cachedEvent.Headered(roomInfo.RoomVersion), nil
+ }
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
if err != nil {
return nil, err
@@ -1324,7 +1332,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
// isn't a failure.
- eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes)
+ eventTypeNIDMap, err := d.eventTypeNIDs(ctx, nil, eventTypes)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
}