aboutsummaryrefslogtreecommitdiff
path: root/roomserver/state
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/state')
-rw-r--r--roomserver/state/state.go60
1 files changed, 57 insertions, 3 deletions
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index ca0c69f2..a40a2e9b 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -23,12 +23,11 @@ import (
"sync"
"time"
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
-
- "github.com/matrix-org/dendrite/roomserver/types"
- "github.com/matrix-org/gomatrixserverlib"
)
type StateResolutionStorage interface {
@@ -124,6 +123,61 @@ func (v *StateResolution) LoadStateAtEvent(
return stateEntries, nil
}
+func (v *StateResolution) LoadMembershipAtEvent(
+ ctx context.Context, eventIDs []string, stateKeyNID types.EventStateKeyNID,
+) (map[string][]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadMembershipAtEvent")
+ defer span.Finish()
+
+ // De-dupe snapshotNIDs
+ snapshotNIDMap := make(map[types.StateSnapshotNID][]string) // map from snapshot NID to eventIDs
+ for i := range eventIDs {
+ eventID := eventIDs[i]
+ snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
+ if err != nil {
+ return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %w", eventID, err)
+ }
+ if snapshotNID == 0 {
+ return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
+ }
+ snapshotNIDMap[snapshotNID] = append(snapshotNIDMap[snapshotNID], eventID)
+ }
+
+ snapshotNIDs := make([]types.StateSnapshotNID, 0, len(snapshotNIDMap))
+ for nid := range snapshotNIDMap {
+ snapshotNIDs = append(snapshotNIDs, nid)
+ }
+
+ stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, snapshotNIDs)
+ if err != nil {
+ return nil, err
+ }
+
+ result := make(map[string][]types.StateEntry)
+ for _, stateBlockNIDList := range stateBlockNIDLists {
+ // Query the membership event for the user at the given stateblocks
+ stateEntryLists, err := v.db.StateEntriesForTuples(ctx, stateBlockNIDList.StateBlockNIDs, []types.StateKeyTuple{
+ {
+ EventTypeNID: types.MRoomMemberNID,
+ EventStateKeyNID: stateKeyNID,
+ },
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ evIDs := snapshotNIDMap[stateBlockNIDList.StateSnapshotNID]
+
+ for _, evID := range evIDs {
+ for _, x := range stateEntryLists {
+ result[evID] = append(result[evID], x.StateEntries...)
+ }
+ }
+ }
+
+ return result, nil
+}
+
// LoadStateAtEvent loads the full state of a room before a particular event.
func (v *StateResolution) LoadStateAtEventForHistoryVisibility(
ctx context.Context, eventID string,