aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/query/query.go')
-rw-r--r--roomserver/internal/query/query.go70
1 files changed, 42 insertions, 28 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index ae2b7cf5..caea6b52 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -48,7 +48,7 @@ type Queryer struct {
Cfg *config.Dendrite
}
-func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, userID spec.UserID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
+func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, localServerName spec.ServerName) (*gomatrixserverlib.RestrictedRoomJoinInfo, error) {
roomInfo, err := r.QueryRoomInfo(ctx, roomID)
if err != nil || roomInfo == nil || roomInfo.IsStub() {
return nil, err
@@ -64,7 +64,7 @@ func (r *Queryer) RestrictedRoomJoinInfo(ctx context.Context, roomID spec.RoomID
return nil, fmt.Errorf("InternalServerError: Failed to query room: %w", err)
}
- userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), userID)
+ userJoinedToRoom, err := r.UserJoinedToRoom(ctx, types.RoomNID(roomInfo.RoomNID), senderID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("rsAPI.UserJoinedToRoom failed")
return nil, fmt.Errorf("InternalServerError: %w", err)
@@ -220,13 +220,14 @@ func (r *Queryer) QueryEventsByID(
return nil
}
-// QueryMembershipForUser implements api.RoomserverInternalAPI
-func (r *Queryer) QueryMembershipForUser(
+// QueryMembershipForSenderID implements api.RoomserverInternalAPI
+func (r *Queryer) QueryMembershipForSenderID(
ctx context.Context,
- request *api.QueryMembershipForUserRequest,
+ roomID spec.RoomID,
+ senderID spec.SenderID,
response *api.QueryMembershipForUserResponse,
) error {
- info, err := r.DB.RoomInfo(ctx, request.RoomID)
+ info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil {
return err
}
@@ -236,7 +237,7 @@ func (r *Queryer) QueryMembershipForUser(
}
response.RoomExists = true
- membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID)
+ membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID)
if err != nil {
return err
}
@@ -264,6 +265,24 @@ func (r *Queryer) QueryMembershipForUser(
return err
}
+// QueryMembershipForUser implements api.RoomserverInternalAPI
+func (r *Queryer) QueryMembershipForUser(
+ ctx context.Context,
+ request *api.QueryMembershipForUserRequest,
+ response *api.QueryMembershipForUserResponse,
+) error {
+ senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
+ if err != nil {
+ return err
+ }
+
+ roomID, err := spec.NewRoomID(request.RoomID)
+ if err != nil {
+ return err
+ }
+ return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
+}
+
// QueryMembershipAtEvent returns the known memberships at a given event.
// If the state before an event is not known, an empty list will be returned
// for that event instead.
@@ -373,7 +392,7 @@ func (r *Queryer) QueryMembershipsForRoom(
// If no sender is specified then we will just return the entire
// set of memberships for the room, regardless of whether a specific
// user is allowed to see them or not.
- if request.Sender == "" {
+ if request.SenderID == "" {
var events []types.Event
var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, request.LocalOnly)
@@ -388,18 +407,15 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
- sender := spec.UserID{}
- userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
- if queryErr == nil && userID != nil {
- sender = *userID
- }
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
+ clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.QueryUserIDForSender(ctx, roomID, senderID)
+ }, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
}
- membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender)
+ membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, request.SenderID)
if err != nil {
return err
}
@@ -442,12 +458,9 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
- sender := spec.UserID{}
- userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
- if err == nil && userID != nil {
- sender = *userID
- }
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
+ clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.QueryUserIDForSender(ctx, roomID, senderID)
+ }, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
@@ -489,6 +502,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
ctx context.Context,
serverName spec.ServerName,
eventID string,
+ roomID string,
) (allowed bool, err error) {
events, err := r.DB.EventNIDs(ctx, []string{eventID})
if err != nil {
@@ -518,7 +532,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
}
return helpers.CheckServerAllowedToSeeEvent(
- ctx, r.DB, info, eventID, serverName, isInRoom,
+ ctx, r.DB, info, roomID, eventID, serverName, isInRoom,
)
}
@@ -909,8 +923,8 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
return nil
}
-func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (bool, error) {
- pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), userID.String())
+func (r *Queryer) InvitePending(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (bool, error) {
+ pending, _, _, _, err := helpers.IsInvitePending(ctx, r.DB, roomID.String(), senderID)
return pending, err
}
@@ -926,8 +940,8 @@ func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eve
return res, err
}
-func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, userID spec.UserID) (bool, error) {
- _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, userID.String())
+func (r *Queryer) UserJoinedToRoom(ctx context.Context, roomNID types.RoomNID, senderID spec.SenderID) (bool, error) {
+ _, isIn, _, err := r.DB.GetMembership(ctx, roomNID, senderID)
return isIn, err
}
@@ -957,7 +971,7 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse
}
// nolint:gocyclo
-func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) {
+func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) {
// Look up if we know anything about the room. If it doesn't exist
// or is a stub entry then we can't do anything.
roomInfo, err := r.DB.RoomInfo(ctx, roomID.String())
@@ -972,7 +986,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return "", err
}
- return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID)
+ return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
}
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {