diff options
Diffstat (limited to 'roomserver/internal/query/query.go')
-rw-r--r-- | roomserver/internal/query/query.go | 26 |
1 files changed, 19 insertions, 7 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 69d841dd..1083bb23 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -216,7 +217,8 @@ func (r *Queryer) QueryMembershipAtEvent( request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse, ) error { - response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent) + response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) + info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return fmt.Errorf("unable to get roomInfo: %w", err) @@ -234,7 +236,17 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) } - stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID]) + response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...) + switch err { + case nil: + return nil + case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event + default: + return err + } + + response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) if err != nil { return fmt.Errorf("unable to get state before event: %w", err) } @@ -258,7 +270,7 @@ func (r *Queryer) QueryMembershipAtEvent( for _, eventID := range request.EventIDs { stateEntry, ok := stateEntries[eventID] if !ok || len(stateEntry) == 0 { - response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{} + response.Membership[eventID] = nil continue } @@ -275,15 +287,15 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("unable to get memberships at state: %w", err) } - res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) - + // Iterate over all membership events we got. Given we only query the membership for + // one user and assuming this user only ever has one membership event associated to + // a given event, overwrite any other existing membership events. for i := range memberships { ev := memberships[i] if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { - res = append(res, ev.Headered(info.RoomVersion)) + response.Membership[eventID] = ev.Event.Headered(info.RoomVersion) } } - response.Memberships[eventID] = res } return nil |