aboutsummaryrefslogtreecommitdiff
path: root/roomserver/state
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-06-07 14:23:26 +0100
committerNeil Alexander <neilalexander@users.noreply.github.com>2022-06-07 14:23:26 +0100
commit27948fb30468315ce613402dc8cc1fa7dba01679 (patch)
treec619df228460d99aee4d19e629d483fb9e3db002 /roomserver/state
parentaafb7bf120d30c37219686a5bb528794b0ab44a2 (diff)
Optimise `loadAuthEvents`, add roomserver tracing
Diffstat (limited to 'roomserver/state')
-rw-r--r--roomserver/state/state.go219
1 files changed, 170 insertions, 49 deletions
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index 95abdcb3..6c4e4b86 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -20,9 +20,11 @@ import (
"context"
"fmt"
"sort"
+ "sync"
"time"
"github.com/matrix-org/util"
+ "github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/roomserver/types"
@@ -62,6 +64,9 @@ func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) Sta
func (v *StateResolution) LoadStateAtSnapshot(
ctx context.Context, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshot")
+ defer span.Finish()
+
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil {
return nil, err
@@ -100,6 +105,9 @@ func (v *StateResolution) LoadStateAtSnapshot(
func (v *StateResolution) LoadStateAtEvent(
ctx context.Context, eventID string,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent")
+ defer span.Finish()
+
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err)
@@ -122,6 +130,9 @@ func (v *StateResolution) LoadStateAtEvent(
func (v *StateResolution) LoadCombinedStateAfterEvents(
ctx context.Context, prevStates []types.StateAtEvent,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadCombinedStateAfterEvents")
+ defer span.Finish()
+
stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
for i, state := range prevStates {
stateNIDs[i] = state.BeforeStateSnapshotNID
@@ -194,6 +205,9 @@ func (v *StateResolution) LoadCombinedStateAfterEvents(
func (v *StateResolution) DifferenceBetweeenStateSnapshots(
ctx context.Context, oldStateNID, newStateNID types.StateSnapshotNID,
) (removed, added []types.StateEntry, err error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.DifferenceBetweeenStateSnapshots")
+ defer span.Finish()
+
if oldStateNID == newStateNID {
// If the snapshot NIDs are the same then nothing has changed
return nil, nil, nil
@@ -255,6 +269,9 @@ func (v *StateResolution) LoadStateAtSnapshotForStringTuples(
stateNID types.StateSnapshotNID,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshotForStringTuples")
+ defer span.Finish()
+
numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples)
if err != nil {
return nil, err
@@ -269,6 +286,9 @@ func (v *StateResolution) stringTuplesToNumericTuples(
ctx context.Context,
stringTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateKeyTuple, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.stringTuplesToNumericTuples")
+ defer span.Finish()
+
eventTypes := make([]string, len(stringTuples))
stateKeys := make([]string, len(stringTuples))
for i := range stringTuples {
@@ -311,6 +331,9 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples(
stateNID types.StateSnapshotNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAtSnapshotForNumericTuples")
+ defer span.Finish()
+
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil {
return nil, err
@@ -359,6 +382,9 @@ func (v *StateResolution) LoadStateAfterEventsForStringTuples(
prevStates []types.StateAtEvent,
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAfterEventsForStringTuples")
+ defer span.Finish()
+
numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples)
if err != nil {
return nil, err
@@ -371,6 +397,9 @@ func (v *StateResolution) loadStateAfterEventsForNumericTuples(
prevStates []types.StateAtEvent,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAfterEventsForNumericTuples")
+ defer span.Finish()
+
if len(prevStates) == 1 {
// Fast path for a single event.
prevState := prevStates[0]
@@ -543,6 +572,9 @@ func (v *StateResolution) CalculateAndStoreStateBeforeEvent(
event *gomatrixserverlib.Event,
isRejected bool,
) (types.StateSnapshotNID, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent")
+ defer span.Finish()
+
// Load the state at the prev events.
prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs())
if err != nil {
@@ -559,6 +591,9 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
ctx context.Context,
prevStates []types.StateAtEvent,
) (types.StateSnapshotNID, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateAfterEvents")
+ defer span.Finish()
+
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
if len(prevStates) == 0 {
@@ -631,6 +666,9 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents(
prevStates []types.StateAtEvent,
metrics calculateStateMetrics,
) (types.StateSnapshotNID, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateAndStoreStateAfterManyEvents")
+ defer span.Finish()
+
state, algorithm, conflictLength, err :=
v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates)
metrics.algorithm = algorithm
@@ -649,6 +687,9 @@ func (v *StateResolution) calculateStateAfterManyEvents(
ctx context.Context, roomVersion gomatrixserverlib.RoomVersion,
prevStates []types.StateAtEvent,
) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateStateAfterManyEvents")
+ defer span.Finish()
+
var combined []types.StateEntry
// Conflict resolution.
// First stage: load the state after each of the prev events.
@@ -701,6 +742,9 @@ func (v *StateResolution) resolveConflicts(
ctx context.Context, version gomatrixserverlib.RoomVersion,
notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflicts")
+ defer span.Finish()
+
stateResAlgo, err := version.StateResAlgorithm()
if err != nil {
return nil, err
@@ -725,6 +769,8 @@ func (v *StateResolution) resolveConflictsV1(
ctx context.Context,
notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV1")
+ defer span.Finish()
// Load the conflicted events
conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted)
@@ -788,6 +834,9 @@ func (v *StateResolution) resolveConflictsV2(
ctx context.Context,
notConflicted, conflicted []types.StateEntry,
) ([]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV2")
+ defer span.Finish()
+
estimate := len(conflicted) + len(notConflicted)
eventIDMap := make(map[string]types.StateEntry, estimate)
@@ -815,31 +864,47 @@ func (v *StateResolution) resolveConflictsV2(
authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3)
gotAuthEvents := make(map[string]struct{}, estimate*3)
authDifference := make([]*gomatrixserverlib.Event, 0, estimate)
+ knownAuthEvents := make(map[string]types.Event, estimate*3)
// For each conflicted event, let's try and get the needed auth events.
- for _, conflictedEvent := range conflictedEvents {
- // Work out which auth events we need to load.
- key := conflictedEvent.EventID()
-
- // Store the newly found auth events in the auth set for this event.
- var authEventMap map[string]types.StateEntry
- authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent)
- if err != nil {
- return nil, err
- }
- for k, v := range authEventMap {
- eventIDMap[k] = v
+ if err = func() error {
+ span, sctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadAuthEvents")
+ defer span.Finish()
+
+ loader := authEventLoader{
+ v: v,
+ lookupFromDB: make([]string, 0, len(conflictedEvents)*3),
+ lookupFromMem: make([]string, 0, len(conflictedEvents)*3),
+ lookedUpEvents: make([]types.Event, 0, len(conflictedEvents)*3),
+ eventMap: map[string]types.Event{},
}
+ for _, conflictedEvent := range conflictedEvents {
+ // Work out which auth events we need to load.
+ key := conflictedEvent.EventID()
+
+ // Store the newly found auth events in the auth set for this event.
+ var authEventMap map[string]types.StateEntry
+ authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, conflictedEvent, knownAuthEvents)
+ if err != nil {
+ return err
+ }
+ for k, v := range authEventMap {
+ eventIDMap[k] = v
+ }
- // Only add auth events into the authEvents slice once, otherwise the
- // check for the auth difference can become expensive and produce
- // duplicate entries, which just waste memory and CPU time.
- for _, event := range authSets[key] {
- if _, ok := gotAuthEvents[event.EventID()]; !ok {
- authEvents = append(authEvents, event)
- gotAuthEvents[event.EventID()] = struct{}{}
+ // Only add auth events into the authEvents slice once, otherwise the
+ // check for the auth difference can become expensive and produce
+ // duplicate entries, which just waste memory and CPU time.
+ for _, event := range authSets[key] {
+ if _, ok := gotAuthEvents[event.EventID()]; !ok {
+ authEvents = append(authEvents, event)
+ gotAuthEvents[event.EventID()] = struct{}{}
+ }
}
}
+ return nil
+ }(); err != nil {
+ return nil, err
}
// Kill the reference to this so that the GC may pick it up, since we no
@@ -870,19 +935,29 @@ func (v *StateResolution) resolveConflictsV2(
// Look through all of the auth events that we've been given and work out if
// there are any events which don't appear in all of the auth sets. If they
// don't then we add them to the auth difference.
- for _, event := range authEvents {
- if !isInAllAuthLists(event) {
- authDifference = append(authDifference, event)
+ func() {
+ span, _ := opentracing.StartSpanFromContext(ctx, "isInAllAuthLists")
+ defer span.Finish()
+
+ for _, event := range authEvents {
+ if !isInAllAuthLists(event) {
+ authDifference = append(authDifference, event)
+ }
}
- }
+ }()
// Resolve the conflicts.
- resolvedEvents := gomatrixserverlib.ResolveStateConflictsV2(
- conflictedEvents,
- nonConflictedEvents,
- authEvents,
- authDifference,
- )
+ resolvedEvents := func() []*gomatrixserverlib.Event {
+ span, _ := opentracing.StartSpanFromContext(ctx, "gomatrixserverlib.ResolveStateConflictsV2")
+ defer span.Finish()
+
+ return gomatrixserverlib.ResolveStateConflictsV2(
+ conflictedEvents,
+ nonConflictedEvents,
+ authEvents,
+ authDifference,
+ )
+ }()
// Map from the full events back to numeric state entries.
for _, resolvedEvent := range resolvedEvents {
@@ -947,6 +1022,9 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E
func (v *StateResolution) loadStateEvents(
ctx context.Context, entries []types.StateEntry,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateEvents")
+ defer span.Finish()
+
result := make([]*gomatrixserverlib.Event, 0, len(entries))
eventEntries := make([]types.StateEntry, 0, len(entries))
eventNIDs := make([]types.EventNID, 0, len(entries))
@@ -975,43 +1053,86 @@ func (v *StateResolution) loadStateEvents(
return result, eventIDMap, nil
}
+type authEventLoader struct {
+ sync.Mutex
+ v *StateResolution
+ lookupFromDB []string // scratch space
+ lookupFromMem []string // scratch space
+ lookedUpEvents []types.Event // scratch space
+ eventMap map[string]types.Event
+}
+
// loadAuthEvents loads all of the auth events for a given event recursively,
// along with a map that contains state entries for all of the auth events.
-func (v *StateResolution) loadAuthEvents(
- ctx context.Context, event *gomatrixserverlib.Event,
+func (l *authEventLoader) loadAuthEvents(
+ ctx context.Context, event *gomatrixserverlib.Event, eventMap map[string]types.Event,
) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) {
- eventMap := map[string]struct{}{}
- var lookup []string
- var authEvents []types.Event
+ l.Lock()
+ defer l.Unlock()
+ authEvents := []types.Event{} // our returned list
+ included := map[string]struct{}{} // dedupes authEvents above
queue := event.AuthEventIDs()
for i := 0; i < len(queue); i++ {
- lookup = lookup[:0]
+ // Reuse the same underlying memory, since it reduces the
+ // amount of allocations we make the more times we call
+ // loadAuthEvents.
+ l.lookupFromDB = l.lookupFromDB[:0]
+ l.lookupFromMem = l.lookupFromMem[:0]
+ l.lookedUpEvents = l.lookedUpEvents[:0]
+
+ // Separate out the list of events in the queue based on if
+ // we think we already know the event in memory or not.
for _, authEventID := range queue {
- if _, ok := eventMap[authEventID]; ok {
+ if _, ok := included[authEventID]; ok {
continue
}
- lookup = append(lookup, authEventID)
+ if _, ok := eventMap[authEventID]; ok {
+ l.lookupFromMem = append(l.lookupFromMem, authEventID)
+ } else {
+ l.lookupFromDB = append(l.lookupFromDB, authEventID)
+ }
}
- if len(lookup) == 0 {
+ // If there's nothing to do, stop here.
+ if len(l.lookupFromDB) == 0 && len(l.lookupFromMem) == 0 {
break
}
- events, err := v.db.EventsFromIDs(ctx, lookup)
- if err != nil {
- return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
+
+ // If we need to get events from the database, go and fetch
+ // those now.
+ if len(l.lookupFromDB) > 0 {
+ eventsFromDB, err := l.v.db.EventsFromIDs(ctx, l.lookupFromDB)
+ if err != nil {
+ return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err)
+ }
+ l.lookedUpEvents = append(l.lookedUpEvents, eventsFromDB...)
+ for _, event := range eventsFromDB {
+ eventMap[event.EventID()] = event
+ }
+ }
+
+ // Fill in the gaps with events that we already have in memory.
+ if len(l.lookupFromMem) > 0 {
+ for _, eventID := range l.lookupFromMem {
+ l.lookedUpEvents = append(l.lookedUpEvents, eventMap[eventID])
+ }
}
+
+ // From the events that we've retrieved, work out which auth
+ // events to look up on the next iteration.
add := map[string]struct{}{}
- for _, event := range events {
- eventMap[event.EventID()] = struct{}{}
+ for _, event := range l.lookedUpEvents {
authEvents = append(authEvents, event)
+ included[event.EventID()] = struct{}{}
+
for _, authEventID := range event.AuthEventIDs() {
- if _, ok := eventMap[authEventID]; ok {
+ if _, ok := included[authEventID]; ok {
continue
}
add[authEventID] = struct{}{}
}
- for authEventID := range add {
- queue = append(queue, authEventID)
- }
+ }
+ for authEventID := range add {
+ queue = append(queue, authEventID)
}
}
authEventTypes := map[string]struct{}{}
@@ -1028,11 +1149,11 @@ func (v *StateResolution) loadAuthEvents(
for eventStateKey := range authEventStateKeys {
lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey)
}
- eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes)
+ eventTypes, err := l.v.db.EventTypeNIDs(ctx, lookupAuthEventTypes)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err)
}
- eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys)
+ eventStateKeys, err := l.v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys)
if err != nil {
return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err)
}