aboutsummaryrefslogtreecommitdiff
path: root/roomserver
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-06-14 14:23:46 +0000
committerGitHub <noreply@github.com>2023-06-14 14:23:46 +0000
commite4665979bfbe006368d55189f074e456fe19b198 (patch)
treee909d694a022478d0dbe3cc58ee8a2dc289bc969 /roomserver
parent7a2e325d1014d76188b47a011730a42443f3c174 (diff)
Merge SenderID & Per Room User Key work (#3109)
Diffstat (limited to 'roomserver')
-rw-r--r--roomserver/api/api.go8
-rw-r--r--roomserver/auth/auth.go16
-rw-r--r--roomserver/auth/auth_test.go10
-rw-r--r--roomserver/internal/alias.go12
-rw-r--r--roomserver/internal/api.go1
-rw-r--r--roomserver/internal/helpers/auth.go8
-rw-r--r--roomserver/internal/helpers/helpers.go38
-rw-r--r--roomserver/internal/input/input_events.go35
-rw-r--r--roomserver/internal/input/input_events_test.go2
-rw-r--r--roomserver/internal/input/input_latest_events.go2
-rw-r--r--roomserver/internal/input/input_membership.go6
-rw-r--r--roomserver/internal/input/input_missing.go22
-rw-r--r--roomserver/internal/perform/perform_admin.go20
-rw-r--r--roomserver/internal/perform/perform_backfill.go38
-rw-r--r--roomserver/internal/perform/perform_create_room.go37
-rw-r--r--roomserver/internal/perform/perform_inbound_peek.go2
-rw-r--r--roomserver/internal/perform/perform_invite.go31
-rw-r--r--roomserver/internal/perform/perform_join.go153
-rw-r--r--roomserver/internal/perform/perform_leave.go10
-rw-r--r--roomserver/internal/perform/perform_upgrade.go10
-rw-r--r--roomserver/internal/query/query.go77
-rw-r--r--roomserver/roomserver_test.go5
-rw-r--r--roomserver/state/state.go14
-rw-r--r--roomserver/storage/interface.go10
-rw-r--r--roomserver/storage/postgres/user_room_keys_table.go19
-rw-r--r--roomserver/storage/shared/room_updater.go5
-rw-r--r--roomserver/storage/shared/storage.go55
-rw-r--r--roomserver/storage/shared/storage_test.go7
-rw-r--r--roomserver/storage/sqlite3/user_room_keys_table.go19
-rw-r--r--roomserver/storage/tables/interface.go2
-rw-r--r--roomserver/storage/tables/user_room_keys_table_test.go7
31 files changed, 441 insertions, 240 deletions
diff --git a/roomserver/api/api.go b/roomserver/api/api.go
index fec28841..e2dd5dd7 100644
--- a/roomserver/api/api.go
+++ b/roomserver/api/api.go
@@ -51,6 +51,7 @@ type RoomserverInternalAPI interface {
UserRoomserverAPI
FederationRoomserverAPI
QuerySenderIDAPI
+ UserRoomPrivateKeyCreator
// needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs
@@ -67,7 +68,9 @@ type RoomserverInternalAPI interface {
req *QueryAuthChainRequest,
res *QueryAuthChainResponse,
) error
+}
+type UserRoomPrivateKeyCreator interface {
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
}
@@ -81,8 +84,8 @@ type InputRoomEventsAPI interface {
}
type QuerySenderIDAPI interface {
- QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
- QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
+ QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
+ QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error)
}
// Query the latest events and state for a room from the room server.
@@ -228,6 +231,7 @@ type FederationRoomserverAPI interface {
QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI
QuerySenderIDAPI
+ UserRoomPrivateKeyCreator
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go
index ba10a433..d6c10cf9 100644
--- a/roomserver/auth/auth.go
+++ b/roomserver/auth/auth.go
@@ -15,7 +15,7 @@ package auth
import (
"context"
- "github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
@@ -25,7 +25,7 @@ import (
// IsServerAllowed returns true if the server is allowed to see events in the room
// at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87
func IsServerAllowed(
- ctx context.Context, db storage.RoomDatabase,
+ ctx context.Context, querier api.QuerySenderIDAPI,
serverName spec.ServerName,
serverCurrentlyInRoom bool,
authEvents []gomatrixserverlib.PDU,
@@ -41,7 +41,7 @@ func IsServerAllowed(
return true
}
// 2. If the user's membership was join, allow.
- joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join)
+ joinedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Join)
if joinedUserExists {
return true
}
@@ -50,7 +50,7 @@ func IsServerAllowed(
return true
}
// 4. If the user's membership was invite, and the history_visibility was set to invited, allow.
- invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite)
+ invitedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Invite)
if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited {
return true
}
@@ -74,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver
return visibility
}
-func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
+func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySenderIDAPI, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
for _, ev := range authEvents {
if ev.Type() != spec.MRoomMember {
continue
@@ -89,7 +89,11 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabas
continue
}
- userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey))
+ validRoomID, err := spec.NewRoomID(ev.RoomID())
+ if err != nil {
+ continue
+ }
+ userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey))
if err != nil {
continue
}
diff --git a/roomserver/auth/auth_test.go b/roomserver/auth/auth_test.go
index 192d9e5d..058361e6 100644
--- a/roomserver/auth/auth_test.go
+++ b/roomserver/auth/auth_test.go
@@ -4,17 +4,17 @@ import (
"context"
"testing"
- "github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
-type FakeStorageDB struct {
- storage.RoomDatabase
+type FakeQuerier struct {
+ api.QuerySenderIDAPI
}
-func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
@@ -87,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) {
authEvents = append(authEvents, ev.PDU)
}
- if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
+ if got := IsServerAllowed(context.Background(), &FakeQuerier{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
}
})
diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go
index c950024a..e6fb7338 100644
--- a/roomserver/internal/alias.go
+++ b/roomserver/internal/alias.go
@@ -113,6 +113,7 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID(
return nil
}
+// nolint:gocyclo
// RemoveRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) RemoveRoomAlias(
ctx context.Context,
@@ -129,7 +130,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return nil
}
- sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID)
+ validRoomID, err := spec.NewRoomID(roomID)
+ if err != nil {
+ return err
+ }
+
+ sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID)
if err != nil || sender == nil {
return fmt.Errorf("r.QueryUserIDForSender: %w", err)
}
@@ -177,7 +183,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
if request.SenderID != ev.SenderID() {
senderID = ev.SenderID()
}
- sender, err := r.QueryUserIDForSender(ctx, roomID, senderID)
+ sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID)
if err != nil || sender == nil {
return err
}
@@ -206,7 +212,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
}
stateRes := &api.QueryLatestEventsAndStateResponse{}
- if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
+ if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
return err
}
diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go
index 4bcd3f3e..7943ae5c 100644
--- a/roomserver/internal/api.go
+++ b/roomserver/internal/api.go
@@ -177,6 +177,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
IsLocalServerName: r.Cfg.Global.IsLocalServerName,
DB: r.DB,
FSAPI: r.fsAPI,
+ Querier: r.Queryer,
KeyRing: r.KeyRing,
// Perspective servers are trusted to not lie about server keys, so we will also
// prefer these servers when backfilling (assuming they are in the room) rather
diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go
index 7782d07d..89fae244 100644
--- a/roomserver/internal/helpers/auth.go
+++ b/roomserver/internal/helpers/auth.go
@@ -22,6 +22,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
+ "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
@@ -36,6 +37,7 @@ func CheckForSoftFail(
roomInfo *types.RoomInfo,
event *types.HeaderedEvent,
stateEventIDs []string,
+ querier api.QuerySenderIDAPI,
) (bool, error) {
rewritesState := len(stateEventIDs) > 1
@@ -49,7 +51,7 @@ func CheckForSoftFail(
} else {
// Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now.
- roomState := state.NewStateResolution(db, roomInfo)
+ roomState := state.NewStateResolution(db, roomInfo, querier)
authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID())
if err != nil {
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
@@ -76,8 +78,8 @@ func CheckForSoftFail(
}
// Check if the event is allowed.
- if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return db.GetUserIDForSender(ctx, roomID, senderID)
+ if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return querier.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
// return true, nil
return true, err
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index 263cb9f8..febabf41 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -68,7 +68,7 @@ func UpdateToInviteMembership(
// memberships. If the servername is not supplied then the local server will be
// checked instead using a faster code path.
// TODO: This should probably be replaced by an API call.
-func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName spec.ServerName, roomID string) (bool, error) {
+func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) {
info, err := db.RoomInfo(ctx, roomID)
if err != nil {
return false, err
@@ -94,7 +94,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
for i := range events {
gmslEvents[i] = events[i].PDU
}
- return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil
+ return auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, gmslEvents, spec.Join), nil
}
func IsInvitePending(
@@ -211,8 +211,8 @@ func GetMembershipsAtState(
return events, nil
}
-func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
- roomState := state.NewStateResolution(db, info)
+func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) {
+ roomState := state.NewStateResolution(db, info, querier)
// Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil {
@@ -229,8 +229,8 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
-func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
- roomState := state.NewStateResolution(db, info)
+func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) {
+ roomState := state.NewStateResolution(db, info, querier)
// Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
}
@@ -264,7 +264,7 @@ func LoadStateEvents(
}
func CheckServerAllowedToSeeEvent(
- ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool,
+ ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI,
) (bool, error) {
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
switch err {
@@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
case tables.OptimisationNotSupportedError:
// The database engine didn't support this optimisation, so fall back to using
// the old and slow method
- stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName)
+ stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier)
if err != nil {
return false, err
}
@@ -288,13 +288,13 @@ func CheckServerAllowedToSeeEvent(
return false, err
}
}
- return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil
+ return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil
}
func slowGetHistoryVisibilityState(
- ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName,
+ ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI,
) ([]gomatrixserverlib.PDU, error) {
- roomState := state.NewStateResolution(db, info)
+ roomState := state.NewStateResolution(db, info, querier)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
@@ -318,9 +318,13 @@ func slowGetHistoryVisibilityState(
// If the event state key doesn't match the given servername
// then we'll filter it out. This does preserve state keys that
// are "" since these will contain history visibility etc.
+ validRoomID, err := spec.NewRoomID(roomID)
+ if err != nil {
+ return nil, err
+ }
for nid, key := range stateKeys {
if key != "" {
- userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key))
+ userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key))
if err == nil && userID != nil {
if userID.Domain() != serverName {
delete(stateKeys, nid)
@@ -349,7 +353,7 @@ func slowGetHistoryVisibilityState(
// TODO: Remove this when we have tests to assert correctness of this function
func ScanEventTree(
ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
- serverName spec.ServerName,
+ serverName spec.ServerName, querier api.QuerySenderIDAPI,
) ([]types.EventNID, map[string]struct{}, error) {
var resultNIDs []types.EventNID
var err error
@@ -392,7 +396,7 @@ BFSLoop:
// It's nasty that we have to extract the room ID from an event, but many federation requests
// only talk in event IDs, no room IDs at all (!!!)
ev := events[0]
- isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
+ isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
}
@@ -415,7 +419,7 @@ BFSLoop:
// hasn't been seen before.
if !visited[pre] {
visited[pre] = true
- allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom)
+ allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier)
if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event",
@@ -444,7 +448,7 @@ BFSLoop:
}
func QueryLatestEventsAndState(
- ctx context.Context, db storage.Database,
+ ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI,
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
@@ -457,7 +461,7 @@ func QueryLatestEventsAndState(
return nil
}
- roomState := state.NewStateResolution(db, roomInfo)
+ roomState := state.NewStateResolution(db, roomInfo, querier)
response.RoomExists = true
response.RoomVersion = roomInfo.RoomVersion
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 7bb40163..aa05d959 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -128,7 +128,11 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
}
- sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ validRoomID, err := spec.NewRoomID(event.RoomID())
+ if err != nil {
+ return err
+ }
+ sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil {
return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
}
@@ -282,8 +286,8 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted.
- if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
isRejected = true
rejectionErr = err
@@ -321,7 +325,7 @@ func (r *Inputer) processRoomEvent(
if input.Kind == api.KindNew && !isCreateEvent {
// Check that the event passes authentication checks based on the
// current room state.
- softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs)
+ softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs, r.Queryer)
if err != nil {
logger.WithError(err).Warn("Error authing soft-failed event")
}
@@ -401,7 +405,7 @@ func (r *Inputer) processRoomEvent(
redactedEvent gomatrixserverlib.PDU
)
if !isRejected && !isCreateEvent {
- resolver := state.NewStateResolution(r.DB, roomInfo)
+ resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer)
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver)
if err != nil {
return err
@@ -587,8 +591,8 @@ func (r *Inputer) processStateBefore(
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent),
)
- if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
return
@@ -700,8 +704,8 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway.
- if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue nextAuthEvent
}
@@ -718,8 +722,8 @@ nextAuthEvent:
}
// Check if the auth event should be rejected.
- err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
})
if isRejected = err != nil; isRejected {
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
@@ -783,7 +787,7 @@ func (r *Inputer) calculateAndSetState(
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
- roomState := state.NewStateResolution(updater, roomInfo)
+ roomState := state.NewStateResolution(updater, roomInfo, r.Queryer)
if input.HasState {
// We've been told what the state at the event is so we don't need to calculate it.
@@ -836,13 +840,18 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
return err
}
+ validRoomID, err := spec.NewRoomID(event.RoomID())
+ if err != nil {
+ return err
+ }
+
prevEvents := latestRes.LatestEvents
for _, memberEvent := range memberEvents {
if memberEvent.StateKey() == nil {
continue
}
- memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey()))
+ memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey()))
if err != nil {
continue
}
diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go
index 5f2cd956..4ee6d211 100644
--- a/roomserver/internal/input/input_events_test.go
+++ b/roomserver/internal/input/input_events_test.go
@@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
}
// Finally check that the event is NOT allowed
- if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}); err == nil {
t.Fatalf("event should not be allowed, but it was")
diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go
index 7a7a021a..940783e0 100644
--- a/roomserver/internal/input/input_latest_events.go
+++ b/roomserver/internal/input/input_latest_events.go
@@ -213,7 +213,7 @@ func (u *latestEventsUpdater) latestState() error {
defer trace.EndRegion()
var err error
- roomState := state.NewStateResolution(u.updater, u.roomInfo)
+ roomState := state.NewStateResolution(u.updater, u.roomInfo, u.api.Queryer)
// Work out if the state at the extremities has actually changed
// or not. If they haven't then we won't bother doing all of the
diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go
index 09c65dfe..c46f8dba 100644
--- a/roomserver/internal/input/input_membership.go
+++ b/roomserver/internal/input/input_membership.go
@@ -139,7 +139,11 @@ func (r *Inputer) updateMembership(
func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil {
- userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey))
+ validRoomID, err := spec.NewRoomID(event.RoomID())
+ if err != nil {
+ return isTargetLocalUser
+ }
+ userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey))
if err != nil || userID == nil {
return isTargetLocalUser
}
diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go
index f0f974d2..7ee84e4c 100644
--- a/roomserver/internal/input/input_missing.go
+++ b/roomserver/internal/input/input_missing.go
@@ -383,7 +383,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
defer trace.EndRegion()
var res parsedRespState
- roomState := state.NewStateResolution(t.db, t.roomInfo)
+ roomState := state.NewStateResolution(t.db, t.roomInfo, t.inputer.Queryer)
stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID})
if err != nil {
t.log.WithError(err).Warnf("failed to get state after %s locally", eventID)
@@ -473,8 +473,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
stateEventList = append(stateEventList, state.StateEvents...)
}
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
- roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
},
)
if err != nil {
@@ -482,8 +482,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
}
// apply the current event
retryAllowedState:
- if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError:
@@ -569,8 +569,8 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
// will be added and duplicates will be removed.
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue
}
@@ -660,8 +660,8 @@ func (t *missingStateReq) lookupMissingStateViaState(
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(),
- }, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ }, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
return nil, err
@@ -897,8 +897,8 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
}
- if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err}
diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go
index ec13bff8..12b557f5 100644
--- a/roomserver/internal/perform/perform_admin.go
+++ b/roomserver/internal/perform/perform_admin.go
@@ -74,6 +74,10 @@ func (r *Admin) PerformAdminEvacuateRoom(
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
return nil, err
}
+ validRoomID, err := spec.NewRoomID(roomID)
+ if err != nil {
+ return nil, err
+ }
prevEvents := latestRes.LatestEvents
var senderDomain spec.ServerName
@@ -100,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
PrevEvents: prevEvents,
}
- userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID))
+ userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(fledglingEvent.SenderID))
if err != nil || userID == nil {
continue
}
@@ -264,16 +268,16 @@ func (r *Admin) PerformAdminDownloadState(
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
}
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue
}
authEventMap[authEvent.EventID()] = authEvent
}
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
continue
}
@@ -293,7 +297,11 @@ func (r *Admin) PerformAdminDownloadState(
stateIDs = append(stateIDs, stateEvent.EventID())
}
- senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID)
+ validRoomID, err := spec.NewRoomID(roomID)
+ if err != nil {
+ return err
+ }
+ senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil {
return err
}
diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go
index 8e87359a..533ad25b 100644
--- a/roomserver/internal/perform/perform_backfill.go
+++ b/roomserver/internal/perform/perform_backfill.go
@@ -42,6 +42,7 @@ type Backfiller struct {
DB storage.Database
FSAPI federationAPI.RoomserverFederationAPI
KeyRing gomatrixserverlib.JSONVerifier
+ Querier api.QuerySenderIDAPI
// The servers which should be preferred above other servers when backfilling
PreferServers []spec.ServerName
@@ -79,7 +80,7 @@ func (r *Backfiller) PerformBackfill(
}
// Scan the event tree for events to send back.
- resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
+ resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r.Querier)
if err != nil {
return err
}
@@ -113,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
if info == nil || info.IsStub() {
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
}
- requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
+ requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
// Request 100 items regardless of what the query asks for.
// We don't want to go much higher than this.
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
@@ -121,8 +122,8 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill(
ctx, req.VirtualHost, requester,
- r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
},
)
// Only return an error if we really couldn't get any events.
@@ -135,7 +136,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
// persist these new events - auth checks have already been done
- roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
+ roomNID, backfilledEventMap := persistEvents(ctx, r.DB, r.Querier, events)
for _, ev := range backfilledEventMap {
// now add state for these events
@@ -212,8 +213,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue
}
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
- result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
logger.WithError(err).Warn("failed to load and verify event")
@@ -246,13 +247,14 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
}
}
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
- persistEvents(ctx, r.DB, newEvents)
+ persistEvents(ctx, r.DB, r.Querier, newEvents)
}
// backfillRequester implements gomatrixserverlib.BackfillRequester
type backfillRequester struct {
db storage.Database
fsAPI federationAPI.RoomserverFederationAPI
+ querier api.QuerySenderIDAPI
virtualHost spec.ServerName
isLocalServerName func(spec.ServerName) bool
preferServer map[spec.ServerName]bool
@@ -268,6 +270,7 @@ type backfillRequester struct {
func newBackfillRequester(
db storage.Database, fsAPI federationAPI.RoomserverFederationAPI,
+ querier api.QuerySenderIDAPI,
virtualHost spec.ServerName,
isLocalServerName func(spec.ServerName) bool,
bwExtrems map[string][]string, preferServers []spec.ServerName,
@@ -279,6 +282,7 @@ func newBackfillRequester(
return &backfillRequester{
db: db,
fsAPI: fsAPI,
+ querier: querier,
virtualHost: virtualHost,
isLocalServerName: isLocalServerName,
eventIDToBeforeStateIDs: make(map[string][]string),
@@ -460,14 +464,14 @@ FindSuccessor:
return nil
}
- stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID)
+ stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID, b.querier)
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
return nil
}
// possibly return all joined servers depending on history visiblity
- memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost)
+ memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, b.querier, info, stateEntries, b.virtualHost)
b.historyVisiblity = visibility
if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
@@ -488,7 +492,11 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[spec.ServerName]bool)
for _, event := range memberEvents {
- if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
+ validRoomID, err := spec.NewRoomID(event.RoomID())
+ if err != nil {
+ continue
+ }
+ if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil {
serverSet[sender.Domain()] = true
}
}
@@ -554,7 +562,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
// pull all events and then filter by that table.
func joinEventsFromHistoryVisibility(
- ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
+ ctx context.Context, db storage.RoomDatabase, querier api.QuerySenderIDAPI, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
var eventNIDs []types.EventNID
@@ -582,7 +590,7 @@ func joinEventsFromHistoryVisibility(
}
// Can we see events in the room?
- canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events)
+ canSeeEvents := auth.IsServerAllowed(ctx, querier, thisServer, true, events)
visibility := auth.HistoryVisibilityForRoom(events)
if !canSeeEvents {
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
@@ -597,7 +605,7 @@ func joinEventsFromHistoryVisibility(
return evs, visibility, err
}
-func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) {
+func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) {
var roomNID types.RoomNID
var eventNID types.EventNID
backfilledEventMap := make(map[string]types.Event)
@@ -639,7 +647,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
continue
}
- resolver := state.NewStateResolution(db, roomInfo)
+ resolver := state.NewStateResolution(db, roomInfo, querier)
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver)
if err != nil {
diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go
index 121b257e..fd8055e0 100644
--- a/roomserver/internal/perform/perform_create_room.go
+++ b/roomserver/internal/perform/perform_create_room.go
@@ -63,13 +63,20 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
}
}
- senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID)
- if err != nil {
- util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
- return "", &util.JSONResponse{
- Code: http.StatusInternalServerError,
- JSON: spec.InternalServerError{},
+ var senderID spec.SenderID
+ if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
+ // create user room key if needed
+ key, keyErr := c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
+ if keyErr != nil {
+ util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed")
+ return "", &util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.InternalServerError{},
+ }
}
+ senderID = spec.SenderID(spec.Base64Bytes(key).Encode())
+ } else {
+ senderID = spec.SenderID(userID.String())
}
createContent["creator"] = senderID
createContent["room_version"] = createRequest.RoomVersion
@@ -323,8 +330,8 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
}
- if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return c.DB.GetUserIDForSender(ctx, roomID, senderID)
+ if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return "", &util.JSONResponse{
@@ -364,18 +371,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
}
- // create user room key if needed
- if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
- _, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
- if err != nil {
- util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed")
- return "", &util.JSONResponse{
- Code: http.StatusInternalServerError,
- JSON: spec.InternalServerError{},
- }
- }
- }
-
// send the remaining events
if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil {
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
@@ -455,7 +450,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
JSON: spec.InternalServerError{},
}
}
- inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID)
+ inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID)
if queryErr != nil {
util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed")
return "", &util.JSONResponse{
diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go
index 3ac0f6f4..7fbec371 100644
--- a/roomserver/internal/perform/perform_inbound_peek.go
+++ b/roomserver/internal/perform/perform_inbound_peek.go
@@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek(
response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]}
// XXX: do we actually need to do a state resolution here?
- roomState := state.NewStateResolution(r.DB, info)
+ roomState := state.NewStateResolution(r.DB, info, r.Inputer.Queryer)
var stateEntries []types.StateEntry
stateEntries, err = roomState.LoadStateAtSnapshot(
diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index cc2c5c19..babd5f81 100644
--- a/roomserver/internal/perform/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -34,6 +34,7 @@ import (
type QueryState struct {
storage.Database
+ querier api.QuerySenderIDAPI
}
func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) {
@@ -46,7 +47,7 @@ func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWant
return nil, fmt.Errorf("failed to load RoomInfo: %w", err)
}
if info != nil {
- roomState := state.NewStateResolution(q.Database, info)
+ roomState := state.NewStateResolution(q.Database, info, q.querier)
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
ctx, info.StateSnapshotNID(), stateWanted,
)
@@ -98,7 +99,11 @@ func (r *Inviter) ProcessInviteMembership(
var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater
- userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
+ validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
+ if err != nil {
+ return nil, err
+ }
+ userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
}
@@ -126,7 +131,12 @@ func (r *Inviter) PerformInvite(
) error {
event := req.Event
- sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ validRoomID, err := spec.NewRoomID(event.RoomID())
+ if err != nil {
+ return err
+ }
+
+ sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
if err != nil {
return spec.InvalidParam("The sender user ID is invalid")
}
@@ -137,18 +147,13 @@ func (r *Inviter) PerformInvite(
if event.StateKey() == nil || *event.StateKey() == "" {
return fmt.Errorf("invite must be a state event")
}
- invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
+ invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
if err != nil || invitedUser == nil {
return spec.InvalidParam("Could not find the matching senderID for this user")
}
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
- validRoomID, err := spec.NewRoomID(event.RoomID())
- if err != nil {
- return err
- }
-
- invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser)
+ invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser)
if err != nil {
return fmt.Errorf("failed looking up senderID for invited user")
}
@@ -161,9 +166,9 @@ func (r *Inviter) PerformInvite(
IsTargetLocal: isTargetLocal,
StrippedState: req.InviteRoomState,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
- StateQuerier: &QueryState{r.DB},
- UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ StateQuerier: &QueryState{r.DB, r.RSAPI},
+ UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
},
}
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go
index 74ed87c7..5867ee6e 100644
--- a/roomserver/internal/perform/perform_join.go
+++ b/roomserver/internal/perform/perform_join.go
@@ -25,6 +25,7 @@ import (
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
+ "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -174,44 +175,6 @@ func (r *Joiner) performJoinRoomByID(
req.ServerNames = append(req.ServerNames, roomID.Domain())
}
- // Prepare the template for the join event.
- userID, err := spec.NewUserID(req.UserID, true)
- if err != nil {
- return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
- }
- senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomIDOrAlias, *userID)
- if err != nil {
- return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
- }
- senderIDString := string(senderID)
- userDomain := userID.Domain()
- proto := gomatrixserverlib.ProtoEvent{
- Type: spec.MRoomMember,
- SenderID: senderIDString,
- StateKey: &senderIDString,
- RoomID: req.RoomIDOrAlias,
- Redacts: "",
- }
- if err = proto.SetUnsigned(struct{}{}); err != nil {
- return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
- }
-
- // It is possible for the request to include some "content" for the
- // event. We'll always overwrite the "membership" key, but the rest,
- // like "display_name" or "avatar_url", will be kept if supplied.
- if req.Content == nil {
- req.Content = map[string]interface{}{}
- }
- req.Content["membership"] = spec.Join
- if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
- return "", "", aerr
- } else if authorisedVia != "" {
- req.Content["join_authorised_via_users_server"] = authorisedVia
- }
- if err = proto.SetContent(req.Content); err != nil {
- return "", "", fmt.Errorf("eb.SetContent: %w", err)
- }
-
// Force a federated join if we aren't in the room and we've been
// given some server names to try joining by.
inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{
@@ -224,29 +187,63 @@ func (r *Joiner) performJoinRoomByID(
serverInRoom := inRoomRes.IsInRoom
forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom
+ userID, err := spec.NewUserID(req.UserID, true)
+ if err != nil {
+ return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
+ }
+
+ // Look up the room NID for the supplied room ID.
+ var senderID spec.SenderID
+ checkInvitePending := false
+ info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
+ if err == nil && info != nil {
+ switch info.RoomVersion {
+ case gomatrixserverlib.RoomVersionPseudoIDs:
+ senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
+ if err == nil {
+ checkInvitePending = true
+ } else {
+ // create user room key if needed
+ key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID)
+ if keyErr != nil {
+ util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed")
+ return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr)
+ }
+ senderID = spec.SenderID(spec.Base64Bytes(key).Encode())
+ }
+ default:
+ checkInvitePending = true
+ senderID = spec.SenderID(userID.String())
+ }
+ }
+
+ userDomain := userID.Domain()
+
// Force a federated join if we're dealing with a pending invite
// and we aren't in the room.
- isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
- if err == nil && !serverInRoom && isInvitePending {
- inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender)
- if queryErr != nil {
- return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
- }
+ if checkInvitePending {
+ isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
+ if inviteErr == nil && !serverInRoom && isInvitePending {
+ inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender)
+ if queryErr != nil {
+ return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
+ }
- // If we were invited by someone from another server then we can
- // assume they are in the room so we can join via them.
- if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
- req.ServerNames = append(req.ServerNames, inviter.Domain())
- forceFederatedJoin = true
- memberEvent := gjson.Parse(string(inviteEvent.JSON()))
- // only set unsigned if we've got a content.membership, which we _should_
- if memberEvent.Get("content.membership").Exists() {
- req.Unsigned = map[string]interface{}{
- "prev_sender": memberEvent.Get("sender").Str,
- "prev_content": map[string]interface{}{
- "is_direct": memberEvent.Get("content.is_direct").Bool(),
- "membership": memberEvent.Get("content.membership").Str,
- },
+ // If we were invited by someone from another server then we can
+ // assume they are in the room so we can join via them.
+ if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
+ req.ServerNames = append(req.ServerNames, inviter.Domain())
+ forceFederatedJoin = true
+ memberEvent := gjson.Parse(string(inviteEvent.JSON()))
+ // only set unsigned if we've got a content.membership, which we _should_
+ if memberEvent.Get("content.membership").Exists() {
+ req.Unsigned = map[string]interface{}{
+ "prev_sender": memberEvent.Get("sender").Str,
+ "prev_content": map[string]interface{}{
+ "is_direct": memberEvent.Get("content.is_direct").Bool(),
+ "membership": memberEvent.Get("content.membership").Str,
+ },
+ }
}
}
}
@@ -274,6 +271,7 @@ func (r *Joiner) performJoinRoomByID(
// If we should do a forced federated join then do that.
var joinedVia spec.ServerName
if forceFederatedJoin {
+ // TODO : pseudoIDs - pass through userID here since we don't know what the senderID should be yet
joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
return req.RoomIDOrAlias, joinedVia, err
}
@@ -289,19 +287,40 @@ func (r *Joiner) performJoinRoomByID(
if err != nil {
return "", "", fmt.Errorf("error joining local room: %q", err)
}
+
+ senderIDString := string(senderID)
+
+ // Prepare the template for the join event.
+ proto := gomatrixserverlib.ProtoEvent{
+ Type: spec.MRoomMember,
+ SenderID: senderIDString,
+ StateKey: &senderIDString,
+ RoomID: req.RoomIDOrAlias,
+ Redacts: "",
+ }
+ if err = proto.SetUnsigned(struct{}{}); err != nil {
+ return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
+ }
+
+ // It is possible for the request to include some "content" for the
+ // event. We'll always overwrite the "membership" key, but the rest,
+ // like "display_name" or "avatar_url", will be kept if supplied.
+ if req.Content == nil {
+ req.Content = map[string]interface{}{}
+ }
+ req.Content["membership"] = spec.Join
+ if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
+ return "", "", aerr
+ } else if authorisedVia != "" {
+ req.Content["join_authorised_via_users_server"] = authorisedVia
+ }
+ if err = proto.SetContent(req.Content); err != nil {
+ return "", "", fmt.Errorf("eb.SetContent: %w", err)
+ }
event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes)
switch err.(type) {
case nil:
- // create user room key if needed
- if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
- _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID)
- if err != nil {
- logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed")
- return "", "", fmt.Errorf("failed to get user room private key: %w", err)
- }
- }
-
// The room join is local. Send the new join event into the
// roomserver. First of all check that the user isn't already
// a member of the room. This is best-effort (as in we won't
diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go
index 1b23cc1f..e1ddb9b5 100644
--- a/roomserver/internal/perform/perform_leave.go
+++ b/roomserver/internal/perform/perform_leave.go
@@ -78,7 +78,11 @@ func (r *Leaver) performLeaveRoomByID(
req *api.PerformLeaveRequest,
res *api.PerformLeaveResponse, // nolint:unparam
) ([]api.OutputEvent, error) {
- leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver)
+ roomID, err := spec.NewRoomID(req.RoomID)
+ if err != nil {
+ return nil, err
+ }
+ leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
if err != nil {
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
}
@@ -87,7 +91,7 @@ func (r *Leaver) performLeaveRoomByID(
// that.
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
if err == nil && isInvitePending {
- sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser)
+ sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
if serr != nil || sender == nil {
return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
}
@@ -133,7 +137,7 @@ func (r *Leaver) performLeaveRoomByID(
},
}
latestRes := api.QueryLatestEventsAndStateResponse{}
- if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil {
+ if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil {
return nil, err
}
if !latestRes.RoomExists {
diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go
index 1aaa42c9..32f547dc 100644
--- a/roomserver/internal/perform/perform_upgrade.go
+++ b/roomserver/internal/perform/perform_upgrade.go
@@ -54,7 +54,11 @@ func (r *Upgrader) performRoomUpgrade(
return "", err
}
- senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID)
+ fullRoomID, err := spec.NewRoomID(roomID)
+ if err != nil {
+ return "", err
+ }
+ senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, *fullRoomID, userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
return "", err
@@ -488,7 +492,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send
}
- if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
@@ -569,7 +573,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send
stateEvents[i] = queryRes.StateEvents[i].PDU
}
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
- if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
}); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index caea6b52..19fd456b 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -16,6 +16,7 @@ package query
import (
"context"
+ "crypto/ed25519"
"database/sql"
"errors"
"fmt"
@@ -89,7 +90,7 @@ func (r *Queryer) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
- return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response)
+ return helpers.QueryLatestEventsAndState(ctx, r.DB, r, request, response)
}
// QueryStateAfterEvents implements api.RoomserverInternalAPI
@@ -106,7 +107,7 @@ func (r *Queryer) QueryStateAfterEvents(
return nil
}
- roomState := state.NewStateResolution(r.DB, info)
+ roomState := state.NewStateResolution(r.DB, info, r)
response.RoomExists = true
response.RoomVersion = info.RoomVersion
@@ -159,8 +160,8 @@ func (r *Queryer) QueryStateAfterEvents(
}
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.QueryUserIDForSender(ctx, roomID, senderID)
},
)
if err != nil {
@@ -271,15 +272,15 @@ func (r *Queryer) QueryMembershipForUser(
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
- senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
+ roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
-
- roomID, err := spec.NewRoomID(request.RoomID)
+ senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
+
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
}
@@ -320,7 +321,7 @@ func (r *Queryer) QueryMembershipAtEvent(
}
response.Membership = make(map[string]*types.HeaderedEvent)
- stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID])
+ stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r)
if err != nil {
return fmt.Errorf("unable to get state before event: %w", err)
}
@@ -407,7 +408,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
- clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
@@ -445,7 +446,7 @@ func (r *Queryer) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
} else {
- stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
+ stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID, r)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
@@ -458,7 +459,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
- clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
@@ -532,7 +533,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
}
return helpers.CheckServerAllowedToSeeEvent(
- ctx, r.DB, info, roomID, eventID, serverName, isInRoom,
+ ctx, r.DB, info, roomID, eventID, serverName, isInRoom, r,
)
}
@@ -573,7 +574,7 @@ func (r *Queryer) QueryMissingEvents(
return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
}
- resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
+ resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r)
if err != nil {
return err
}
@@ -651,8 +652,8 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.QueryUserIDForSender(ctx, roomID, senderID)
},
)
if err != nil {
@@ -673,7 +674,7 @@ func (r *Queryer) QueryStateAndAuthChain(
// first bool: is rejected, second bool: state missing
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) {
- roomState := state.NewStateResolution(r.DB, roomInfo)
+ roomState := state.NewStateResolution(r.DB, roomInfo, r)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil {
switch err.(type) {
@@ -989,10 +990,46 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
}
-func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
- return r.DB.GetSenderIDForUser(ctx, roomID, userID)
+func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
+ version, err := r.DB.GetRoomVersion(ctx, roomID.String())
+ if err != nil {
+ return "", err
+ }
+
+ switch version {
+ case gomatrixserverlib.RoomVersionPseudoIDs:
+ key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
+ if err != nil {
+ return "", err
+ }
+ return spec.SenderID(spec.Base64Bytes(key).Encode()), nil
+ default:
+ return spec.SenderID(userID.String()), nil
+ }
}
-func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ userID, err := spec.NewUserID(string(senderID), true)
+ if err == nil {
+ return userID, nil
+ }
+
+ bytes := spec.Base64Bytes{}
+ err = bytes.Decode(string(senderID))
+ if err != nil {
+ return nil, err
+ }
+ queryMap := map[spec.RoomID][]ed25519.PublicKey{roomID: {ed25519.PublicKey(bytes)}}
+ result, err := r.DB.SelectUserIDsForPublicKeys(ctx, queryMap)
+ if err != nil {
+ return nil, err
+ }
+
+ if userKeys, ok := result[roomID]; ok {
+ if userID, ok := userKeys[string(senderID)]; ok {
+ return spec.NewUserID(userID, true)
+ }
+ }
+
+ return nil, nil
}
diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go
index 90c94bbc..077957fa 100644
--- a/roomserver/roomserver_test.go
+++ b/roomserver/roomserver_test.go
@@ -516,6 +516,9 @@ func TestRedaction(t *testing.T) {
t.Fatal(err)
}
+ natsInstance := &jetstream.NATSInstance{}
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
+
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
authEvents := []types.EventNID{}
@@ -551,7 +554,7 @@ func TestRedaction(t *testing.T) {
}
// Calculate the snapshotNID etc.
- plResolver := state.NewStateResolution(db, roomInfo)
+ plResolver := state.NewStateResolution(db, roomInfo, rsAPI)
stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.PDU, false)
assert.NoError(t, err)
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index b9c5bbc4..1e776ff6 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -29,6 +29,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
)
@@ -44,20 +45,21 @@ type StateResolutionStorage interface {
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
- GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
}
type StateResolution struct {
db StateResolutionStorage
roomInfo *types.RoomInfo
events map[types.EventNID]gomatrixserverlib.PDU
+ Querier api.QuerySenderIDAPI
}
-func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
+func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo, querier api.QuerySenderIDAPI) StateResolution {
return StateResolution{
db: db,
roomInfo: roomInfo,
events: make(map[types.EventNID]gomatrixserverlib.PDU),
+ Querier: querier,
}
}
@@ -947,8 +949,8 @@ func (v *StateResolution) resolveConflictsV1(
}
// Resolve the conflicts.
- resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return v.db.GetUserIDForSender(ctx, roomID, senderID)
+ resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
})
// Map from the full events back to numeric state entries.
@@ -1061,8 +1063,8 @@ func (v *StateResolution) resolveConflictsV2(
conflictedEvents,
nonConflictedEvents,
authEvents,
- func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return v.db.GetUserIDForSender(ctx, roomID, senderID)
+ func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
},
)
}()
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index 7787d9f8..7156c11c 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -169,10 +169,6 @@ type Database interface {
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
// GetKnownUsers searches all users that userID knows about.
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
- // GetKnownUsers tries to obtain the current mxid for a given user.
- GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
- // GetKnownUsers tries to obtain the current senderID for a given user.
- GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
// GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error)
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
@@ -190,6 +186,7 @@ type Database interface {
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*types.HeaderedEvent, error)
GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error)
+ GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
MaybeRedactEvent(
@@ -205,8 +202,12 @@ type UserRoomKeys interface {
InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
// SelectUserRoomPrivateKey selects the private key for the given user and room combination
SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error)
+ // SelectUserRoomPublicKey selects the public key for the given user and room combination
+ SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error)
// SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID.
// If a senderKey can't be found, it is omitted in the result.
+ // TODO: Why is the result map indexed by string not public key?
+ // TODO: Shouldn't the input & result map be changed to be indexed by string instead of the RoomID struct?
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error)
}
@@ -233,7 +234,6 @@ type RoomDatabase interface {
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
- GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
}
type EventDatabase interface {
diff --git a/roomserver/storage/postgres/user_room_keys_table.go b/roomserver/storage/postgres/user_room_keys_table.go
index 22f978bf..dbb4af34 100644
--- a/roomserver/storage/postgres/user_room_keys_table.go
+++ b/roomserver/storage/postgres/user_room_keys_table.go
@@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
+const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
+
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt
+ selectUserRoomPublicKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt
}
@@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL},
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
+ {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
}.Prepare(db)
}
@@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err
}
+func (s *userRoomKeysStatements) SelectUserRoomPublicKey(
+ ctx context.Context,
+ txn *sql.Tx,
+ userNID types.EventStateKeyNID,
+ roomNID types.RoomNID,
+) (ed25519.PublicKey, error) {
+ stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt)
+ var result ed25519.PublicKey
+ err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return result, err
+}
+
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt)
diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go
index 6fb57332..70672a33 100644
--- a/roomserver/storage/shared/room_updater.go
+++ b/roomserver/storage/shared/room_updater.go
@@ -6,7 +6,6 @@ import (
"fmt"
"github.com/matrix-org/gomatrixserverlib"
- "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/dendrite/roomserver/types"
)
@@ -251,7 +250,3 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}
-
-func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return u.d.GetUserIDForSender(ctx, roomID, senderID)
-}
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index bda51da8..61a3520a 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -721,6 +721,22 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver
}, err
}
+func (d *Database) GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) {
+ cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(roomID)
+ if versionOK {
+ return cachedRoomVersion, nil
+ }
+
+ roomInfo, err := d.RoomInfo(ctx, roomID)
+ if err != nil {
+ return "", err
+ }
+ if roomInfo == nil {
+ return "", nil
+ }
+ return roomInfo.RoomVersion, nil
+}
+
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil {
@@ -1550,16 +1566,6 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
}
-func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- // TODO: Use real logic once DB for pseudoIDs is in place
- return spec.NewUserID(string(senderID), true)
-}
-
-func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
- // TODO: Use real logic once DB for pseudoIDs is in place
- return spec.SenderID(userID.String()), nil
-}
-
// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
@@ -1718,6 +1724,35 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use
return
}
+// SelectUserRoomPublicKey queries the users room public key.
+// If no key exists, returns no key and no error. Otherwise returns
+// the key and a database error, if any.
+func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) {
+ uID := userID.String()
+ stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
+ if sErr != nil {
+ return nil, sErr
+ }
+ stateKeyNID := stateKeyNIDMap[uID]
+
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
+ if rErr != nil {
+ return rErr
+ }
+ if roomInfo == nil {
+ return nil
+ }
+
+ key, sErr = d.UserRoomKeyTable.SelectUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID)
+ if !errors.Is(sErr, sql.ErrNoRows) {
+ return sErr
+ }
+ return nil
+ })
+ return
+}
+
// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) {
result = make(map[spec.RoomID]map[string]string, len(publicKeys))
diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go
index 581d83ee..c7b915c7 100644
--- a/roomserver/storage/shared/storage_test.go
+++ b/roomserver/storage/shared/storage_test.go
@@ -163,12 +163,17 @@ func TestUserRoomKeys(t *testing.T) {
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID)
assert.NoError(t, err)
assert.Equal(t, key, gotKey)
+ pubKey, err := db.SelectUserRoomPublicKey(context.Background(), *userID, *roomID)
+ assert.NoError(t, err)
+ assert.Equal(t, key.Public(), pubKey)
// Key doesn't exist, we shouldn't get anything back
- assert.NoError(t, err)
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist)
assert.NoError(t, err)
assert.Nil(t, gotKey)
+ pubKey, err = db.SelectUserRoomPublicKey(context.Background(), *userID, *doesNotExist)
+ assert.NoError(t, err)
+ assert.Nil(t, pubKey)
queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{
*roomID: {key.Public().(ed25519.PublicKey)},
diff --git a/roomserver/storage/sqlite3/user_room_keys_table.go b/roomserver/storage/sqlite3/user_room_keys_table.go
index 8af57ea0..84c8b54e 100644
--- a/roomserver/storage/sqlite3/user_room_keys_table.go
+++ b/roomserver/storage/sqlite3/user_room_keys_table.go
@@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = `
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
+const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
+
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt
+ selectUserRoomPublicKeyStmt *sql.Stmt
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
}
@@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL},
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
+ {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db)
}
@@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
return result, err
}
+func (s *userRoomKeysStatements) SelectUserRoomPublicKey(
+ ctx context.Context,
+ txn *sql.Tx,
+ userNID types.EventStateKeyNID,
+ roomNID types.RoomNID,
+) (ed25519.PublicKey, error) {
+ stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt)
+ var result ed25519.PublicKey
+ err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return result, err
+}
+
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
roomNIDs := make([]any, 0, len(senderKeys))
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index cd0e5168..445c1223 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -193,6 +193,8 @@ type UserRoomKeys interface {
InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error)
// SelectUserRoomPrivateKey selects the private key for the given user and room combination
SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
+ // SelectUserRoomPublicKey selects the public key for the given user and room combination
+ SelectUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PublicKey, error)
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
// If a senderKey can't be found, it is omitted in the result.
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
diff --git a/roomserver/storage/tables/user_room_keys_table_test.go b/roomserver/storage/tables/user_room_keys_table_test.go
index 28430948..8802a3c6 100644
--- a/roomserver/storage/tables/user_room_keys_table_test.go
+++ b/roomserver/storage/tables/user_room_keys_table_test.go
@@ -50,6 +50,7 @@ func TestUserRoomKeysTable(t *testing.T) {
err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
var gotKey, key2, key3 ed25519.PrivateKey
+ var pubKey ed25519.PublicKey
gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key)
assert.NoError(t, err)
assert.Equal(t, gotKey, key)
@@ -71,6 +72,9 @@ func TestUserRoomKeysTable(t *testing.T) {
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID)
assert.NoError(t, err)
assert.Equal(t, key, gotKey)
+ pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, roomNID)
+ assert.NoError(t, err)
+ assert.Equal(t, key.Public(), pubKey)
// try to update an existing key, this should only be done for users NOT on this homeserver
var gotPubKey ed25519.PublicKey
@@ -82,6 +86,9 @@ func TestUserRoomKeysTable(t *testing.T) {
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2)
assert.NoError(t, err)
assert.Nil(t, gotKey)
+ pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, 2)
+ assert.NoError(t, err)
+ assert.Nil(t, pubKey)
// query user NIDs for senderKeys
var gotKeys map[string]types.UserRoomKeyPair