diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-06-07 14:23:26 +0100 |
---|---|---|
committer | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-06-07 14:23:26 +0100 |
commit | 27948fb30468315ce613402dc8cc1fa7dba01679 (patch) | |
tree | c619df228460d99aee4d19e629d483fb9e3db002 /roomserver/state | |
parent | aafb7bf120d30c37219686a5bb528794b0ab44a2 (diff) |
Optimise `loadAuthEvents`, add roomserver tracing
Diffstat (limited to 'roomserver/state')
-rw-r--r-- | roomserver/state/state.go | 219 |
1 file changed, 170 insertions, 49 deletions
diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 95abdcb3..6c4e4b86 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -20,9 +20,11 @@ import ( "context" "fmt" "sort" + "sync" "time" "github.com/matrix-org/util" + "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" "github.com/matrix-org/dendrite/roomserver/types" @@ -62,6 +64,9 @@ func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) Sta func (v *StateResolution) LoadStateAtSnapshot( ctx context.Context, stateNID types.StateSnapshotNID, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshot") + defer span.Finish() + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { return nil, err @@ -100,6 +105,9 @@ func (v *StateResolution) LoadStateAtSnapshot( func (v *StateResolution) LoadStateAtEvent( ctx context.Context, eventID string, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtEvent") + defer span.Finish() + snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err) @@ -122,6 +130,9 @@ func (v *StateResolution) LoadStateAtEvent( func (v *StateResolution) LoadCombinedStateAfterEvents( ctx context.Context, prevStates []types.StateAtEvent, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadCombinedStateAfterEvents") + defer span.Finish() + stateNIDs := make([]types.StateSnapshotNID, len(prevStates)) for i, state := range prevStates { stateNIDs[i] = state.BeforeStateSnapshotNID @@ -194,6 +205,9 @@ func (v *StateResolution) LoadCombinedStateAfterEvents( func (v *StateResolution) DifferenceBetweeenStateSnapshots( ctx context.Context, oldStateNID, newStateNID 
types.StateSnapshotNID, ) (removed, added []types.StateEntry, err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.DifferenceBetweeenStateSnapshots") + defer span.Finish() + if oldStateNID == newStateNID { // If the snapshot NIDs are the same then nothing has changed return nil, nil, nil @@ -255,6 +269,9 @@ func (v *StateResolution) LoadStateAtSnapshotForStringTuples( stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAtSnapshotForStringTuples") + defer span.Finish() + numericTuples, err := v.stringTuplesToNumericTuples(ctx, stateKeyTuples) if err != nil { return nil, err @@ -269,6 +286,9 @@ func (v *StateResolution) stringTuplesToNumericTuples( ctx context.Context, stringTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateKeyTuple, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.stringTuplesToNumericTuples") + defer span.Finish() + eventTypes := make([]string, len(stringTuples)) stateKeys := make([]string, len(stringTuples)) for i := range stringTuples { @@ -311,6 +331,9 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples( stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAtSnapshotForNumericTuples") + defer span.Finish() + stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) if err != nil { return nil, err @@ -359,6 +382,9 @@ func (v *StateResolution) LoadStateAfterEventsForStringTuples( prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.LoadStateAfterEventsForStringTuples") + defer span.Finish() + numericTuples, err := 
v.stringTuplesToNumericTuples(ctx, stateKeyTuples) if err != nil { return nil, err @@ -371,6 +397,9 @@ func (v *StateResolution) loadStateAfterEventsForNumericTuples( prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateAfterEventsForNumericTuples") + defer span.Finish() + if len(prevStates) == 1 { // Fast path for a single event. prevState := prevStates[0] @@ -543,6 +572,9 @@ func (v *StateResolution) CalculateAndStoreStateBeforeEvent( event *gomatrixserverlib.Event, isRejected bool, ) (types.StateSnapshotNID, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateBeforeEvent") + defer span.Finish() + // Load the state at the prev events. prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs()) if err != nil { @@ -559,6 +591,9 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents( ctx context.Context, prevStates []types.StateAtEvent, ) (types.StateSnapshotNID, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.CalculateAndStoreStateAfterEvents") + defer span.Finish() + metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)} if len(prevStates) == 0 { @@ -631,6 +666,9 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents( prevStates []types.StateAtEvent, metrics calculateStateMetrics, ) (types.StateSnapshotNID, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateAndStoreStateAfterManyEvents") + defer span.Finish() + state, algorithm, conflictLength, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates) metrics.algorithm = algorithm @@ -649,6 +687,9 @@ func (v *StateResolution) calculateStateAfterManyEvents( ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, prevStates []types.StateAtEvent, ) (state []types.StateEntry, algorithm 
string, conflictLength int, err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.calculateStateAfterManyEvents") + defer span.Finish() + var combined []types.StateEntry // Conflict resolution. // First stage: load the state after each of the prev events. @@ -701,6 +742,9 @@ func (v *StateResolution) resolveConflicts( ctx context.Context, version gomatrixserverlib.RoomVersion, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflicts") + defer span.Finish() + stateResAlgo, err := version.StateResAlgorithm() if err != nil { return nil, err @@ -725,6 +769,8 @@ func (v *StateResolution) resolveConflictsV1( ctx context.Context, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV1") + defer span.Finish() // Load the conflicted events conflictedEvents, eventIDMap, err := v.loadStateEvents(ctx, conflicted) @@ -788,6 +834,9 @@ func (v *StateResolution) resolveConflictsV2( ctx context.Context, notConflicted, conflicted []types.StateEntry, ) ([]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.resolveConflictsV2") + defer span.Finish() + estimate := len(conflicted) + len(notConflicted) eventIDMap := make(map[string]types.StateEntry, estimate) @@ -815,31 +864,47 @@ func (v *StateResolution) resolveConflictsV2( authEvents := make([]*gomatrixserverlib.Event, 0, estimate*3) gotAuthEvents := make(map[string]struct{}, estimate*3) authDifference := make([]*gomatrixserverlib.Event, 0, estimate) + knownAuthEvents := make(map[string]types.Event, estimate*3) // For each conflicted event, let's try and get the needed auth events. - for _, conflictedEvent := range conflictedEvents { - // Work out which auth events we need to load. 
- key := conflictedEvent.EventID() - - // Store the newly found auth events in the auth set for this event. - var authEventMap map[string]types.StateEntry - authSets[key], authEventMap, err = v.loadAuthEvents(ctx, conflictedEvent) - if err != nil { - return nil, err - } - for k, v := range authEventMap { - eventIDMap[k] = v + if err = func() error { + span, sctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadAuthEvents") + defer span.Finish() + + loader := authEventLoader{ + v: v, + lookupFromDB: make([]string, 0, len(conflictedEvents)*3), + lookupFromMem: make([]string, 0, len(conflictedEvents)*3), + lookedUpEvents: make([]types.Event, 0, len(conflictedEvents)*3), + eventMap: map[string]types.Event{}, } + for _, conflictedEvent := range conflictedEvents { + // Work out which auth events we need to load. + key := conflictedEvent.EventID() + + // Store the newly found auth events in the auth set for this event. + var authEventMap map[string]types.StateEntry + authSets[key], authEventMap, err = loader.loadAuthEvents(sctx, conflictedEvent, knownAuthEvents) + if err != nil { + return err + } + for k, v := range authEventMap { + eventIDMap[k] = v + } - // Only add auth events into the authEvents slice once, otherwise the - // check for the auth difference can become expensive and produce - // duplicate entries, which just waste memory and CPU time. - for _, event := range authSets[key] { - if _, ok := gotAuthEvents[event.EventID()]; !ok { - authEvents = append(authEvents, event) - gotAuthEvents[event.EventID()] = struct{}{} + // Only add auth events into the authEvents slice once, otherwise the + // check for the auth difference can become expensive and produce + // duplicate entries, which just waste memory and CPU time. 
+ for _, event := range authSets[key] { + if _, ok := gotAuthEvents[event.EventID()]; !ok { + authEvents = append(authEvents, event) + gotAuthEvents[event.EventID()] = struct{}{} + } } } + return nil + }(); err != nil { + return nil, err } // Kill the reference to this so that the GC may pick it up, since we no @@ -870,19 +935,29 @@ func (v *StateResolution) resolveConflictsV2( // Look through all of the auth events that we've been given and work out if // there are any events which don't appear in all of the auth sets. If they // don't then we add them to the auth difference. - for _, event := range authEvents { - if !isInAllAuthLists(event) { - authDifference = append(authDifference, event) + func() { + span, _ := opentracing.StartSpanFromContext(ctx, "isInAllAuthLists") + defer span.Finish() + + for _, event := range authEvents { + if !isInAllAuthLists(event) { + authDifference = append(authDifference, event) + } } - } + }() // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflictsV2( - conflictedEvents, - nonConflictedEvents, - authEvents, - authDifference, - ) + resolvedEvents := func() []*gomatrixserverlib.Event { + span, _ := opentracing.StartSpanFromContext(ctx, "gomatrixserverlib.ResolveStateConflictsV2") + defer span.Finish() + + return gomatrixserverlib.ResolveStateConflictsV2( + conflictedEvents, + nonConflictedEvents, + authEvents, + authDifference, + ) + }() // Map from the full events back to numeric state entries. 
for _, resolvedEvent := range resolvedEvents { @@ -947,6 +1022,9 @@ func (v *StateResolution) stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.E func (v *StateResolution) loadStateEvents( ctx context.Context, entries []types.StateEntry, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "StateResolution.loadStateEvents") + defer span.Finish() + result := make([]*gomatrixserverlib.Event, 0, len(entries)) eventEntries := make([]types.StateEntry, 0, len(entries)) eventNIDs := make([]types.EventNID, 0, len(entries)) @@ -975,43 +1053,86 @@ func (v *StateResolution) loadStateEvents( return result, eventIDMap, nil } +type authEventLoader struct { + sync.Mutex + v *StateResolution + lookupFromDB []string // scratch space + lookupFromMem []string // scratch space + lookedUpEvents []types.Event // scratch space + eventMap map[string]types.Event +} + // loadAuthEvents loads all of the auth events for a given event recursively, // along with a map that contains state entries for all of the auth events. -func (v *StateResolution) loadAuthEvents( - ctx context.Context, event *gomatrixserverlib.Event, +func (l *authEventLoader) loadAuthEvents( + ctx context.Context, event *gomatrixserverlib.Event, eventMap map[string]types.Event, ) ([]*gomatrixserverlib.Event, map[string]types.StateEntry, error) { - eventMap := map[string]struct{}{} - var lookup []string - var authEvents []types.Event + l.Lock() + defer l.Unlock() + authEvents := []types.Event{} // our returned list + included := map[string]struct{}{} // dedupes authEvents above queue := event.AuthEventIDs() for i := 0; i < len(queue); i++ { - lookup = lookup[:0] + // Reuse the same underlying memory, since it reduces the + // amount of allocations we make the more times we call + // loadAuthEvents. 
+ l.lookupFromDB = l.lookupFromDB[:0] + l.lookupFromMem = l.lookupFromMem[:0] + l.lookedUpEvents = l.lookedUpEvents[:0] + + // Separate out the list of events in the queue based on if + // we think we already know the event in memory or not. for _, authEventID := range queue { - if _, ok := eventMap[authEventID]; ok { + if _, ok := included[authEventID]; ok { continue } - lookup = append(lookup, authEventID) + if _, ok := eventMap[authEventID]; ok { + l.lookupFromMem = append(l.lookupFromMem, authEventID) + } else { + l.lookupFromDB = append(l.lookupFromDB, authEventID) + } } - if len(lookup) == 0 { + // If there's nothing to do, stop here. + if len(l.lookupFromDB) == 0 && len(l.lookupFromMem) == 0 { break } - events, err := v.db.EventsFromIDs(ctx, lookup) - if err != nil { - return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) + + // If we need to get events from the database, go and fetch + // those now. + if len(l.lookupFromDB) > 0 { + eventsFromDB, err := l.v.db.EventsFromIDs(ctx, l.lookupFromDB) + if err != nil { + return nil, nil, fmt.Errorf("v.db.EventsFromIDs: %w", err) + } + l.lookedUpEvents = append(l.lookedUpEvents, eventsFromDB...) + for _, event := range eventsFromDB { + eventMap[event.EventID()] = event + } + } + + // Fill in the gaps with events that we already have in memory. + if len(l.lookupFromMem) > 0 { + for _, eventID := range l.lookupFromMem { + l.lookedUpEvents = append(l.lookedUpEvents, eventMap[eventID]) + } } + + // From the events that we've retrieved, work out which auth + // events to look up on the next iteration. 
add := map[string]struct{}{} - for _, event := range events { - eventMap[event.EventID()] = struct{}{} + for _, event := range l.lookedUpEvents { authEvents = append(authEvents, event) + included[event.EventID()] = struct{}{} + for _, authEventID := range event.AuthEventIDs() { - if _, ok := eventMap[authEventID]; ok { + if _, ok := included[authEventID]; ok { continue } add[authEventID] = struct{}{} } - for authEventID := range add { - queue = append(queue, authEventID) - } + } + for authEventID := range add { + queue = append(queue, authEventID) } } authEventTypes := map[string]struct{}{} @@ -1028,11 +1149,11 @@ func (v *StateResolution) loadAuthEvents( for eventStateKey := range authEventStateKeys { lookupAuthEventStateKeys = append(lookupAuthEventStateKeys, eventStateKey) } - eventTypes, err := v.db.EventTypeNIDs(ctx, lookupAuthEventTypes) + eventTypes, err := l.v.db.EventTypeNIDs(ctx, lookupAuthEventTypes) if err != nil { return nil, nil, fmt.Errorf("v.db.EventTypeNIDs: %w", err) } - eventStateKeys, err := v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys) + eventStateKeys, err := l.v.db.EventStateKeyNIDs(ctx, lookupAuthEventStateKeys) if err != nil { return nil, nil, fmt.Errorf("v.db.EventStateKeyNIDs: %w", err) } |