aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/helpers/helpers.go
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-06-12 11:19:25 +0000
committerGitHub <noreply@github.com>2023-06-12 11:19:25 +0000
commit77d9e4e93dd01f6baa82bd6236850c1007346cac (patch)
tree20be66224646cc82199028cf89f4cd7fab80b97f /roomserver/internal/helpers/helpers.go
parent832ccc32f6a023665e250eee44b5f678e985d50e (diff)
Cleanup remaining statekey usage for senderIDs (#3106)
Diffstat (limited to 'roomserver/internal/helpers/helpers.go')
-rw-r--r--roomserver/internal/helpers/helpers.go37
1 files changed, 21 insertions, 16 deletions
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index 95397cd5..263cb9f8 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"sort"
- "strings"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
@@ -55,9 +54,10 @@ func UpdateToInviteMembership(
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID,
+ RoomID: add.RoomID(),
Membership: spec.Join,
RetiredByEventID: add.EventID(),
- TargetUserID: *add.StateKey(),
+ TargetSenderID: spec.SenderID(*add.StateKey()),
},
})
}
@@ -94,13 +94,13 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
for i := range events {
gmslEvents[i] = events[i].PDU
}
- return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, spec.Join), nil
+ return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil
}
func IsInvitePending(
ctx context.Context, db storage.Database,
- roomID, userID string,
-) (bool, string, string, gomatrixserverlib.PDU, error) {
+ roomID string, senderID spec.SenderID,
+) (bool, spec.SenderID, string, gomatrixserverlib.PDU, error) {
// Look up the room NID for the supplied room ID.
info, err := db.RoomInfo(ctx, roomID)
if err != nil {
@@ -111,13 +111,13 @@ func IsInvitePending(
}
// Look up the state key NID for the supplied user ID.
- targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
+ targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{string(senderID)})
if err != nil {
return false, "", "", nil, fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
}
- targetUserNID, targetUserFound := targetUserNIDs[userID]
+ targetUserNID, targetUserFound := targetUserNIDs[string(senderID)]
if !targetUserFound {
- return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
+ return false, "", "", nil, fmt.Errorf("missing NID for user %q (%+v)", senderID, targetUserNIDs)
}
// Let's see if we have an event active for the user in the room. If
@@ -156,7 +156,7 @@ func IsInvitePending(
event, err := verImpl.NewEventFromTrustedJSON(eventJSON, false)
- return true, senderUser, userNIDToEventID[senderUserNIDs[0]], event, err
+ return true, spec.SenderID(senderUser), userNIDToEventID[senderUserNIDs[0]], event, err
}
// GetMembershipsAtState filters the state events to
@@ -264,7 +264,7 @@ func LoadStateEvents(
}
func CheckServerAllowedToSeeEvent(
- ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName, isServerInRoom bool,
+ ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool,
) (bool, error) {
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
switch err {
@@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
case tables.OptimisationNotSupportedError:
// The database engine didn't support this optimisation, so fall back to using
// the old and slow method
- stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, eventID, serverName)
+ stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName)
if err != nil {
return false, err
}
@@ -288,11 +288,11 @@ func CheckServerAllowedToSeeEvent(
return false, err
}
}
- return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
+ return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil
}
func slowGetHistoryVisibilityState(
- ctx context.Context, db storage.Database, info *types.RoomInfo, eventID string, serverName spec.ServerName,
+ ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName,
) ([]gomatrixserverlib.PDU, error) {
roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
@@ -319,8 +319,13 @@ func slowGetHistoryVisibilityState(
// then we'll filter it out. This does preserve state keys that
// are "" since these will contain history visibility etc.
for nid, key := range stateKeys {
- if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) {
- delete(stateKeys, nid)
+ if key != "" {
+ userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key))
+ if err == nil && userID != nil {
+ if userID.Domain() != serverName {
+ delete(stateKeys, nid)
+ }
+ }
}
}
@@ -410,7 +415,7 @@ BFSLoop:
// hasn't been seen before.
if !visited[pre] {
visited[pre] = true
- allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
+ allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom)
if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event",