From ad07b169b8a58b5a843b7b19ff0a989399d0aea0 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 24 Feb 2023 09:40:20 +0100 Subject: Refactor `StoreEvent` and create a new `RoomDatabase` interface (#2985) This PR changes a few things: - It pulls out the creation of several NIDs from the `StoreEvent` function to make the functions more reusable - Uses more caching when using those NIDs to avoid DB round trips --- roomserver/internal/alias.go | 20 -------- roomserver/internal/api.go | 1 - roomserver/internal/helpers/auth.go | 23 ++++----- roomserver/internal/helpers/helpers.go | 22 ++++----- roomserver/internal/helpers/helpers_test.go | 13 +++++- roomserver/internal/input/input.go | 2 +- roomserver/internal/input/input_events.go | 50 ++++++++++++++++---- roomserver/internal/input/input_membership.go | 2 +- roomserver/internal/input/input_missing.go | 10 ++-- roomserver/internal/perform/perform_admin.go | 2 +- roomserver/internal/perform/perform_backfill.go | 54 +++++++++++++++------- .../internal/perform/perform_inbound_peek.go | 6 +-- roomserver/internal/perform/perform_invite.go | 4 +- roomserver/internal/perform/perform_unpeek.go | 5 +- roomserver/internal/query/query.go | 40 ++++++---------- roomserver/internal/query/query_test.go | 2 +- 16 files changed, 139 insertions(+), 117 deletions(-) (limited to 'roomserver/internal') diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 329e6af7..fc61b7f4 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -30,26 +30,6 @@ import ( "github.com/tidwall/sjson" ) -// RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API. -type RoomserverInternalAPIDatabase interface { - // Save a given room alias with the room ID it refers to. - // Returns an error if there was a problem talking to the database. - SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error - // Look up the room ID a given alias refers to. - // Returns an error if there was a problem talking to the database. - GetRoomIDForAlias(ctx context.Context, alias string) (string, error) - // Look up all aliases referring to a given room ID. - // Returns an error if there was a problem talking to the database. - GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) - // Remove a given room alias. - // Returns an error if there was a problem talking to the database. - RemoveRoomAlias(ctx context.Context, alias string) error - // Look up the room version for a given room. - GetRoomVersionForRoom( - ctx context.Context, roomID string, - ) (gomatrixserverlib.RoomVersion, error) -} - // SetRoomAlias implements alias.RoomserverInternalAPI func (r *RoomserverInternalAPI) SetRoomAlias( ctx context.Context, diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 451b3769..c43b9d04 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -155,7 +155,6 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.Unpeeker = &perform.Unpeeker{ ServerName: r.ServerName, Cfg: r.Cfg, - DB: r.DB, FSAPI: r.fsAPI, Inputer: r.Inputer, } diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 03d8bca0..27c8dd8f 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -31,7 +31,8 @@ import ( // the soft-fail bool. func CheckForSoftFail( ctx context.Context, - db storage.Database, + db storage.RoomDatabase, + roomInfo *types.RoomInfo, event *gomatrixserverlib.HeaderedEvent, stateEventIDs []string, ) (bool, error) { @@ -45,16 +46,6 @@ func CheckForSoftFail( return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err) } } else { - // Work out if the room exists. - var roomInfo *types.RoomInfo - roomInfo, err = db.RoomInfo(ctx, event.RoomID()) - if err != nil { - return false, fmt.Errorf("db.RoomNID: %w", err) - } - if roomInfo == nil || roomInfo.IsStub() { - return false, nil - } - // Then get the state entries for the current state snapshot. // We'll use this to check if the event is allowed right now. roomState := state.NewStateResolution(db, roomInfo) @@ -76,7 +67,7 @@ func CheckForSoftFail( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries) if err != nil { return true, fmt.Errorf("loadAuthEvents: %w", err) } @@ -93,7 +84,8 @@ func CheckForSoftFail( // Returns the numeric IDs for the auth events. func CheckAuthEvents( ctx context.Context, - db storage.Database, + db storage.RoomDatabase, + roomNID types.RoomNID, event *gomatrixserverlib.HeaderedEvent, authEventIDs []string, ) ([]types.EventNID, error) { @@ -108,7 +100,7 @@ func CheckAuthEvents( stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()}) // Load the actual auth events from the database. - authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) + authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries) if err != nil { return nil, fmt.Errorf("loadAuthEvents: %w", err) } @@ -201,6 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) * func loadAuthEvents( ctx context.Context, db state.StateResolutionStorage, + roomNID types.RoomNID, needed gomatrixserverlib.StateNeeded, state []types.StateEntry, ) (result authEvents, err error) { @@ -223,7 +216,7 @@ func loadAuthEvents( eventNIDs = append(eventNIDs, eventNID) } } - if result.events, err = db.Events(ctx, eventNIDs); err != nil { + if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil { return } roomID := "" diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 7efad7af..ee1610cf 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam return false, err } - events, err := db.Events(ctx, eventNIDs) + events, err := db.Events(ctx, info.RoomNID, eventNIDs) if err != nil { return false, err } @@ -157,7 +157,7 @@ func IsInvitePending( // only keep the "m.room.member" events with a "join" membership. These events are returned. // Returns an error if there was an issue fetching the events. func GetMembershipsAtState( - ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, + ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool, ) ([]types.Event, error) { var eventNIDs types.EventNIDs @@ -177,7 +177,7 @@ func GetMembershipsAtState( util.Unique(eventNIDs) // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, roomNID, eventNIDs) if err != nil { return nil, err } @@ -220,16 +220,16 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room return roomState.LoadCombinedStateAfterEvents(ctx, prevState) } -func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { +func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) { roomState := state.NewStateResolution(db, info) // Fetch the state as it was when this event was fired return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID) } func LoadEvents( - ctx context.Context, db storage.Database, eventNIDs []types.EventNID, + ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID, ) ([]*gomatrixserverlib.Event, error) { - stateEvents, err := db.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, roomNID, eventNIDs) if err != nil { return nil, err } @@ -242,13 +242,13 @@ func LoadEvents( } func LoadStateEvents( - ctx context.Context, db storage.Database, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, ) ([]*gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID } - return LoadEvents(ctx, db, eventNIDs) + return LoadEvents(ctx, db, roomNID, eventNIDs) } func CheckServerAllowedToSeeEvent( @@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState( return nil, nil } - return LoadStateEvents(ctx, db, filteredEntries) + return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries) } // TODO: Remove this when we have tests to assert correctness of this function @@ -366,7 +366,7 @@ BFSLoop: next = make([]string, 0) } // Retrieve the events to process from the database. - events, err = db.EventsFromIDs(ctx, front) + events, err = db.EventsFromIDs(ctx, info.RoomNID, front) if err != nil { return resultNIDs, redactEventIDs, err } @@ -467,7 +467,7 @@ func QueryLatestEventsAndState( return err } - stateEvents, err := LoadStateEvents(ctx, db, stateEntries) + stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries) if err != nil { return err } diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go index aa5c30e4..62730df1 100644 --- a/roomserver/internal/helpers/helpers_test.go +++ b/roomserver/internal/helpers/helpers_test.go @@ -38,7 +38,18 @@ func TestIsInvitePendingWithoutNID(t *testing.T) { var authNIDs []types.EventNID for _, x := range room.Events() { - evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false) + roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap()) + assert.NoError(t, err) + assert.Greater(t, roomNID, types.RoomNID(0)) + + eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type()) + assert.NoError(t, err) + assert.Greater(t, eventTypeNID, types.EventTypeNID(0)) + + eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey()) + assert.NoError(t, err) + + evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false) assert.NoError(t, err) authNIDs = append(authNIDs, evNID) } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 94131103..2ec19f01 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -76,7 +76,7 @@ type Inputer struct { Cfg *config.RoomServer Base *base.BaseDendrite ProcessContext *process.ProcessContext - DB storage.Database + DB storage.RoomDatabase NATSClient *nats.Conn JetStream nats.JetStreamContext Durable nats.SubOpt diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 67edb321..fe35efb2 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -308,10 +308,10 @@ func (r *Inputer) processRoomEvent( } var softfail bool - if input.Kind == api.KindNew { + if input.Kind == api.KindNew && !isCreateEvent { // Check that the event passes authentication checks based on the // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs) if err != nil { logger.WithError(err).Warn("Error authing soft-failed event") } @@ -322,8 +322,8 @@ func (r *Inputer) processRoomEvent( // bother doing this if the event was already rejected as it just ends up // burning CPU time. historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared. - if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected { - historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev) + if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent { + historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev) if err != nil { return fmt.Errorf("r.processStateBefore: %w", err) } @@ -332,8 +332,23 @@ func (r *Inputer) processRoomEvent( } } + roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) + } + + eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err) + } + + eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err) + } + // Store the event. - _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected) + _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -474,6 +489,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse // nolint:nakedret func (r *Inputer) processStateBefore( ctx context.Context, + roomNID types.RoomNID, input *api.InputRoomEvent, missingPrev bool, ) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) { @@ -489,7 +505,7 @@ func (r *Inputer) processStateBefore( case input.HasState: // If we're overriding the state then we need to go and retrieve // them from the database. It's a hard error if they are missing. - stateEvents, err := r.DB.EventsFromIDs(ctx, input.StateEventIDs) + stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs) if err != nil { return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err) } @@ -567,6 +583,7 @@ func (r *Inputer) processStateBefore( // we've failed to retrieve the auth chain altogether (in which case // an error is returned) or we've successfully retrieved them all and // they are now in the database. +// nolint: gocyclo func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, @@ -587,7 +604,7 @@ func (r *Inputer) fetchAuthEvents( } for _, authEventID := range authEventIDs { - authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID}) + authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID}) if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil { unknown[authEventID] = struct{}{} continue @@ -673,8 +690,23 @@ nextAuthEvent: logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } + roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err) + } + + eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err) + } + + eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey()) + if err != nil { + return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err) + } + // Finally, store the event in the database. - eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected) + eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected) if err != nil { return fmt.Errorf("updater.StoreEvent: %w", err) } @@ -750,7 +782,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event return err } - memberEvents, err := r.DB.Events(ctx, membershipNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs) if err != nil { return err } diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go index 28a54623..99a01255 100644 --- a/roomserver/internal/input/input_membership.go +++ b/roomserver/internal/input/input_membership.go @@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships( // Load the event JSON so we can look up the "membership" key. // TODO: Maybe add a membership key to the events table so we can load that // key without having to load the entire event JSON? - events, err := updater.Events(ctx, eventNIDs) + events, err := updater.Events(ctx, 0, eventNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 03ac2b38..c8b7d31d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -43,7 +43,7 @@ type missingStateReq struct { log *logrus.Entry virtualHost gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName - db storage.Database + db storage.RoomDatabase roomInfo *types.RoomInfo inputer *Inputer keys gomatrixserverlib.JSONVerifier @@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even for _, entry := range stateEntries { stateEventNIDs = append(stateEventNIDs, entry.EventNID) } - stateEvents, err := t.db.Events(ctx, stateEventNIDs) + stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs) if err != nil { t.log.WithError(err).Warnf("failed to load state events locally") return nil @@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even missingEventList = append(missingEventList, evID) } t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") - events, err := t.db.EventsFromIDs(ctx, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) if err != nil { return nil } @@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } t.haveEventsMutex.Unlock() - events, err := t.db.EventsFromIDs(ctx, missingEventList) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList) if err != nil { return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) } @@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs if localFirst { // fetch from the roomserver - events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) + events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID}) if err != nil { t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) } else if len(events) == 1 { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 3256162b..2efe2255 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -70,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil } - memberEvents, err := r.DB.Events(ctx, memberNIDs) + memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index d9214fdc..3a3a049d 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -86,7 +86,7 @@ func (r *Backfiller) PerformBackfill( // Retrieve events from the list that was filled previously. If we fail to get // events from the database then attempt once to get them from federation instead. var loadedEvents []*gomatrixserverlib.Event - loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs) + loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) if err != nil { if _, ok := err.(types.MissingEventError); ok { return r.backfillViaFederation(ctx, request, response) @@ -258,6 +258,7 @@ type backfillRequester struct { eventIDToBeforeStateIDs map[string][]string eventIDMap map[string]*gomatrixserverlib.Event historyVisiblity gomatrixserverlib.HistoryVisibility + roomInfo types.RoomInfo } func newBackfillRequester( @@ -454,14 +455,14 @@ FindSuccessor: return nil } - stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID]) + stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -472,7 +473,7 @@ FindSuccessor: // Retrieve all "m.room.member" state events of "join" membership, which // contains the list of users in the room before the event, therefore all // the servers in it at that moment. - memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, stateEntries, true) + memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") return nil @@ -523,11 +524,15 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, } eventNIDs := make([]types.EventNID, len(nidMap)) i := 0 + roomNID := b.roomInfo.RoomNID for _, nid := range nidMap { - eventNIDs[i] = nid + eventNIDs[i] = nid.EventNID i++ + if roomNID == 0 { + roomNID = nid.RoomNID + } } - eventsWithNids, err := b.db.Events(ctx, eventNIDs) + eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs) if err != nil { logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events") return nil, err @@ -544,7 +549,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // pull all events and then filter by that table. func joinEventsFromHistoryVisibility( - ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry, + ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry, thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) { var eventNIDs []types.EventNID @@ -557,7 +562,7 @@ func joinEventsFromHistoryVisibility( } // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) + stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs) if err != nil { // even though the default should be shared, restricting the visibility to joined // feels more secure here. @@ -570,21 +575,17 @@ func joinEventsFromHistoryVisibility( // Can we see events in the room? canSeeEvents := auth.IsServerAllowed(thisServer, true, events) - visibility := gomatrixserverlib.HistoryVisibility(auth.HistoryVisibilityForRoom(events)) + visibility := auth.HistoryVisibilityForRoom(events) if !canSeeEvents { logrus.Infof("ServersAtEvent history not visible to us: %s", visibility) return nil, visibility, nil } // get joined members - info, err := db.RoomInfo(ctx, roomID) - if err != nil { - return nil, visibility, nil - } - joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, false) if err != nil { return nil, visibility, err } - evs, err := db.Events(ctx, joinEventNIDs) + evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs) return evs, visibility, err } @@ -601,12 +602,31 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs authNids := make([]types.EventNID, len(nidMap)) i := 0 for _, nid := range nidMap { - authNids[i] = nid + authNids[i] = nid.EventNID i++ } + + roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap()) + if err != nil { + logrus.WithError(err).Error("failed to get or create roomNID") + continue + } + + eventTypeNID, err := db.GetOrCreateEventTypeNID(ctx, ev.Type()) + if err != nil { + logrus.WithError(err).Error("failed to get or create eventType NID") + continue + } + + eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(ctx, ev.StateKey()) + if err != nil { + logrus.WithError(err).Error("failed to get or create eventStateKey NID") + continue + } + var redactedEventID string var redactionEvent *gomatrixserverlib.Event - eventNID, roomNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), authNids, false) + eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index 29decd36..9ac9edc4 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -29,7 +29,7 @@ import ( ) type InboundPeeker struct { - DB storage.Database + DB storage.RoomDatabase Inputer *input.Inputer } @@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - latestEvents, err := r.DB.EventsFromIDs(ctx, []string{latestEventRefs[0].EventID}) + latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID}) if err != nil { return err } @@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek( if err != nil { return err } - stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, stateEntries) + stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) if err != nil { return err } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index f60247cd..118e1b87 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite( // try and see if the user is allowed to make this invite. We can't do // this for invites coming in over federation - we have to take those on // trust. - _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs()) + _, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs()) if err != nil { logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( "processInviteEvent.checkAuthEvents failed for event", @@ -291,7 +291,7 @@ func buildInviteStrippedState( for _, stateNID := range stateEntries { stateNIDs = append(stateNIDs, stateNID.EventNID) } - stateEvents, err := db.Events(ctx, stateNIDs) + stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs) if err != nil { return nil, err } diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go index 0d97da4d..4d714be6 100644 --- a/roomserver/internal/perform/perform_unpeek.go +++ b/roomserver/internal/perform/perform_unpeek.go @@ -22,7 +22,6 @@ import ( fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/input" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" ) @@ -31,9 +30,7 @@ type Unpeeker struct { ServerName gomatrixserverlib.ServerName Cfg *config.RoomServer FSAPI fsAPI.RoomserverFederationAPI - DB storage.Database - - Inputer *input.Inputer + Inputer *input.Inputer } // PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationapi. diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 1083bb23..ac34e0ff 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -102,7 +102,7 @@ func (r *Queryer) QueryStateAfterEvents( return err } - stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries) if err != nil { return err } @@ -138,17 +138,7 @@ func (r *Queryer) QueryEventsByID( request *api.QueryEventsByIDRequest, response *api.QueryEventsByIDResponse, ) error { - eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs) - if err != nil { - return err - } - - var eventNIDs []types.EventNID - for _, nid := range eventNIDMap { - eventNIDs = append(eventNIDs, nid) - } - - events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs) + events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs) if err != nil { return err } @@ -196,7 +186,7 @@ func (r *Queryer) QueryMembershipForUser( response.IsInRoom = stillInRoom response.HasBeenInRoom = true - evs, err := r.DB.Events(ctx, []types.EventNID{membershipEventNID}) + evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID}) if err != nil { return err } @@ -278,10 +268,10 @@ func (r *Queryer) QueryMembershipAtEvent( // once. If we have more than one membership event, we need to get the state for each state entry. if canShortCircuit { if len(memberships) == 0 { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) } } else { - memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false) + memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false) } if err != nil { return fmt.Errorf("unable to get memberships at state: %w", err) @@ -328,7 +318,7 @@ func (r *Queryer) QueryMembershipsForRoom( } return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) } - events, err = r.DB.Events(ctx, eventNIDs) + events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) if err != nil { return fmt.Errorf("r.DB.Events: %w", err) } @@ -367,14 +357,14 @@ func (r *Queryer) QueryMembershipsForRoom( return err } - events, err = r.DB.Events(ctx, eventNIDs) + events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs) } else { stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err } - events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) + events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly) } if err != nil { @@ -425,7 +415,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( request *api.QueryServerAllowedToSeeEventRequest, response *api.QueryServerAllowedToSeeEventResponse, ) (err error) { - events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID}) + events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID}) if err != nil { return } @@ -476,7 +466,7 @@ func (r *Queryer) QueryMissingEvents( eventsToFilter[id] = true } } - events, err := r.DB.EventsFromIDs(ctx, front) + events, err := r.DB.EventsFromIDs(ctx, 0, front) if err != nil { return err } @@ -496,7 +486,7 @@ func (r *Queryer) QueryMissingEvents( return err } - loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs) + loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs) if err != nil { return err } @@ -621,11 +611,11 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI return nil, rejected, false, err } - events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) + events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries) return events, rejected, false, err } -type eventsFromIDs func(context.Context, []string) ([]types.Event, error) +type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error) // GetAuthChain fetches the auth chain for the given auth events. An auth chain // is the list of all events that are referenced in the auth_events section, and @@ -643,7 +633,7 @@ func GetAuthChain( for len(eventsToFetch) > 0 { // Try to retrieve the events from the database. - events, err := fn(ctx, eventsToFetch) + events, err := fn(ctx, 0, eventsToFetch) if err != nil { return nil, err } @@ -981,7 +971,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query // For each of the joined users, let's see if we can get a valid // membership event. for _, joinNID := range joinNIDs { - events, err := r.DB.Events(ctx, []types.EventNID{joinNID}) + events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID}) if err != nil || len(events) != 1 { continue } diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go index 03627ea9..16761157 100644 --- a/roomserver/internal/query/query_test.go +++ b/roomserver/internal/query/query_test.go @@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error { } // EventsFromIDs implements RoomserverInternalAPIEventDB -func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) { +func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) { for _, evID := range eventIDs { res = append(res, types.Event{ EventNID: 0, -- cgit v1.2.3