aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/query/query.go')
-rw-r--r--roomserver/internal/query/query.go26
1 files changed, 19 insertions, 7 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 69d841dd..1083bb23 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -21,6 +21,7 @@ import (
"errors"
"fmt"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@@ -216,7 +217,8 @@ func (r *Queryer) QueryMembershipAtEvent(
request *api.QueryMembershipAtEventRequest,
response *api.QueryMembershipAtEventResponse,
) error {
- response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent)
+ response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent)
+
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err)
@@ -234,7 +236,17 @@ func (r *Queryer) QueryMembershipAtEvent(
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
}
- stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID])
+ response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...)
+ switch err {
+ case nil:
+ return nil
+ case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event
+ default:
+ return err
+ }
+
+ response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent)
+ stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID])
if err != nil {
return fmt.Errorf("unable to get state before event: %w", err)
}
@@ -258,7 +270,7 @@ func (r *Queryer) QueryMembershipAtEvent(
for _, eventID := range request.EventIDs {
stateEntry, ok := stateEntries[eventID]
if !ok || len(stateEntry) == 0 {
- response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{}
+ response.Membership[eventID] = nil
continue
}
@@ -275,15 +287,15 @@ func (r *Queryer) QueryMembershipAtEvent(
return fmt.Errorf("unable to get memberships at state: %w", err)
}
- res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships))
-
+ // Iterate over all membership events we got. Given we only query the membership for
+ // one user and assuming this user only ever has one membership event associated to
+ // a given event, overwrite any other existing membership events.
for i := range memberships {
ev := memberships[i]
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) {
- res = append(res, ev.Headered(info.RoomVersion))
+ response.Membership[eventID] = ev.Event.Headered(info.RoomVersion)
}
}
- response.Memberships[eventID] = res
}
return nil