aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/helpers/helpers.go
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-06-14 14:23:46 +0000
committerGitHub <noreply@github.com>2023-06-14 14:23:46 +0000
commite4665979bfbe006368d55189f074e456fe19b198 (patch)
treee909d694a022478d0dbe3cc58ee8a2dc289bc969 /roomserver/internal/helpers/helpers.go
parent7a2e325d1014d76188b47a011730a42443f3c174 (diff)
Merge SenderID & Per Room User Key work (#3109)
Diffstat (limited to 'roomserver/internal/helpers/helpers.go')
-rw-r--r--roomserver/internal/helpers/helpers.go38
1 files changed, 21 insertions, 17 deletions
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index 263cb9f8..febabf41 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -68,7 +68,7 @@ func UpdateToInviteMembership(
// memberships. If the servername is not supplied then the local server will be
// checked instead using a faster code path.
// TODO: This should probably be replaced by an API call.
-func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName spec.ServerName, roomID string) (bool, error) {
+func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) {
info, err := db.RoomInfo(ctx, roomID)
if err != nil {
return false, err
@@ -94,7 +94,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
for i := range events {
gmslEvents[i] = events[i].PDU
}
- return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil
+ return auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, gmslEvents, spec.Join), nil
}
func IsInvitePending(
@@ -211,8 +211,8 @@ func GetMembershipsAtState(
return events, nil
}
-func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
- roomState := state.NewStateResolution(db, info)
+func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) {
+ roomState := state.NewStateResolution(db, info, querier)
// Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil {
@@ -229,8 +229,8 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
-func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
- roomState := state.NewStateResolution(db, info)
+func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) {
+ roomState := state.NewStateResolution(db, info, querier)
// Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
}
@@ -264,7 +264,7 @@ func LoadStateEvents(
}
func CheckServerAllowedToSeeEvent(
- ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool,
+ ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI,
) (bool, error) {
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
switch err {
@@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
case tables.OptimisationNotSupportedError:
// The database engine didn't support this optimisation, so fall back to using
// the old and slow method
- stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName)
+ stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier)
if err != nil {
return false, err
}
@@ -288,13 +288,13 @@ func CheckServerAllowedToSeeEvent(
return false, err
}
}
- return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil
+ return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil
}
func slowGetHistoryVisibilityState(
- ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName,
+ ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI,
) ([]gomatrixserverlib.PDU, error) {
- roomState := state.NewStateResolution(db, info)
+ roomState := state.NewStateResolution(db, info, querier)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
@@ -318,9 +318,13 @@ func slowGetHistoryVisibilityState(
// If the event state key doesn't match the given servername
// then we'll filter it out. This does preserve state keys that
// are "" since these will contain history visibility etc.
+ validRoomID, err := spec.NewRoomID(roomID)
+ if err != nil {
+ return nil, err
+ }
for nid, key := range stateKeys {
if key != "" {
- userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key))
+ userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key))
if err == nil && userID != nil {
if userID.Domain() != serverName {
delete(stateKeys, nid)
@@ -349,7 +353,7 @@ func slowGetHistoryVisibilityState(
// TODO: Remove this when we have tests to assert correctness of this function
func ScanEventTree(
ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
- serverName spec.ServerName,
+ serverName spec.ServerName, querier api.QuerySenderIDAPI,
) ([]types.EventNID, map[string]struct{}, error) {
var resultNIDs []types.EventNID
var err error
@@ -392,7 +396,7 @@ BFSLoop:
// It's nasty that we have to extract the room ID from an event, but many federation requests
// only talk in event IDs, no room IDs at all (!!!)
ev := events[0]
- isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
+ isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
}
@@ -415,7 +419,7 @@ BFSLoop:
// hasn't been seen before.
if !visited[pre] {
visited[pre] = true
- allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom)
+ allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier)
if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event",
@@ -444,7 +448,7 @@ BFSLoop:
}
func QueryLatestEventsAndState(
- ctx context.Context, db storage.Database,
+ ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI,
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
@@ -457,7 +461,7 @@ func QueryLatestEventsAndState(
return nil
}
- roomState := state.NewStateResolution(db, roomInfo)
+ roomState := state.NewStateResolution(db, roomInfo, querier)
response.RoomExists = true
response.RoomVersion = roomInfo.RoomVersion