From 02a73f29f861c637f30df4a2bb1fce400e481a3c Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 2 Sep 2020 10:02:48 +0100 Subject: Expand RoomInfo to cover more DB storage functions (#1377) * Factor more things to RoomInfo * Factor out remaining bits for RoomInfo * Linting for now --- roomserver/internal/alias.go | 10 ++- roomserver/internal/input_events.go | 24 ++++-- roomserver/internal/input_latest_events.go | 23 ++---- roomserver/internal/perform_backfill.go | 13 ++- roomserver/internal/perform_invite.go | 2 +- roomserver/internal/query.go | 128 +++++++++++++++++------------ 6 files changed, 122 insertions(+), 78 deletions(-) (limited to 'roomserver/internal') diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 4139582b..d576a817 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "time" "github.com/matrix-org/dendrite/roomserver/api" @@ -239,16 +240,19 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( } builder.AuthEvents = refs - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomID) + roomInfo, err := r.DB.RoomInfo(ctx, roomID) if err != nil { return err } + if roomInfo == nil { + return fmt.Errorf("room %s does not exist", roomID) + } // Build the event now := time.Now() event, err := builder.Build( now, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID, - r.Cfg.Matrix.PrivateKey, roomVersion, + r.Cfg.Matrix.PrivateKey, roomInfo.RoomVersion, ) if err != nil { return err @@ -257,7 +261,7 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( // Create the request ire := api.InputRoomEvent{ Kind: api.KindNew, - Event: event.Headered(roomVersion), + Event: event.Headered(roomInfo.RoomVersion), AuthEventIDs: event.AuthEventIDs(), SendAsServer: serverName, } diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go index a6308299..287db1af 100644 --- a/roomserver/internal/input_events.go +++ b/roomserver/internal/input_events.go @@ -64,7 +64,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( } // Store the event. - roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) if err != nil { return "", fmt.Errorf("r.DB.StoreEvent: %w", err) } @@ -89,10 +89,18 @@ func (r *RoomserverInternalAPI) processRoomEvent( return event.EventID(), nil } + roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID()) + if err != nil { + return "", fmt.Errorf("r.DB.RoomInfo: %w", err) + } + if roomInfo == nil { + return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID()) + } + if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event) + err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event) if err != nil { return "", fmt.Errorf("r.calculateAndSetState: %w", err) } @@ -100,7 +108,7 @@ func (r *RoomserverInternalAPI) processRoomEvent( if err = r.updateLatestEvents( ctx, // context - roomNID, // room NID to update + roomInfo, // room info for the room being updated stateAtEvent, // state at event (below) event, // event input.SendAsServer, // send as server @@ -135,19 +143,19 @@ func (r *RoomserverInternalAPI) processRoomEvent( func (r *RoomserverInternalAPI) calculateAndSetState( ctx context.Context, input api.InputRoomEvent, - roomNID types.RoomNID, + roomInfo types.RoomInfo, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, ) error { var err error - roomState := state.NewStateResolution(r.DB) + roomState := state.NewStateResolution(r.DB, roomInfo) if input.HasState { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID // Request join memberships only for local users only. - if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil { + if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil { // If we have no local users that are joined to the room then any state about // the room that we have is quite possibly out of date. Therefore in that case // we should overwrite it rather than merge it. @@ -161,14 +169,14 @@ func (r *RoomserverInternalAPI) calculateAndSetState( return err } - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { return err } } else { stateAtEvent.Overwrite = false // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { return err } } diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go index f11a78d7..d5e38e7a 100644 --- a/roomserver/internal/input_latest_events.go +++ b/roomserver/internal/input_latest_events.go @@ -49,13 +49,13 @@ import ( // Can only be called once at a time func (r *RoomserverInternalAPI) updateLatestEvents( ctx context.Context, - roomNID types.RoomNID, + roomInfo *types.RoomInfo, stateAtEvent types.StateAtEvent, event gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, ) (err error) { - updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID) + updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) if err != nil { return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err) } @@ -66,7 +66,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents( ctx: ctx, api: r, updater: updater, - roomNID: roomNID, + roomInfo: roomInfo, stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer, @@ -89,7 +89,7 @@ type latestEventsUpdater struct { ctx context.Context api *RoomserverInternalAPI updater *shared.LatestEventsUpdater - roomNID types.RoomNID + roomInfo *types.RoomInfo stateAtEvent types.StateAtEvent event gomatrixserverlib.Event transactionID *api.TransactionID @@ -196,7 +196,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { return fmt.Errorf("u.api.WriteOutputEvents: %w", err) } - if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { + if err = u.updater.SetLatestEvents(u.roomInfo.RoomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil { return fmt.Errorf("u.updater.SetLatestEvents: %w", err) } @@ -209,7 +209,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB) + roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) // Get a list of the current latest events. latestStateAtEvents := make([]types.StateAtEvent, len(u.latest)) @@ -221,7 +221,7 @@ func (u *latestEventsUpdater) latestState() error { // of the state after the events. The snapshot state will be resolved // using the correct state resolution algorithm for the room. u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents( - u.ctx, u.roomNID, latestStateAtEvents, + u.ctx, latestStateAtEvents, ) if err != nil { return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err) @@ -303,13 +303,8 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) latestEventIDs[i] = u.latest[i].EventID } - roomVersion, err := u.api.DB.GetRoomVersionForRoom(u.ctx, u.event.RoomID()) - if err != nil { - return nil, err - } - ore := api.OutputNewRoomEvent{ - Event: u.event.Headered(roomVersion), + Event: u.event.Headered(u.roomInfo.RoomVersion), LastSentEventID: u.lastEventIDSent, LatestEventIDs: latestEventIDs, TransactionID: u.transactionID, @@ -337,7 +332,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) // include extra state events if they were added as nearly every downstream component will care about it // and we'd rather not have them all hit QueryEventsByID at the same time! if len(ore.AddsStateEventIDs) > 0 { - ore.AddStateEvents, err = u.extraEventsForIDs(roomVersion, ore.AddsStateEventIDs) + ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs) if err != nil { return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) } diff --git a/roomserver/internal/perform_backfill.go b/roomserver/internal/perform_backfill.go index 65c88860..721f6610 100644 --- a/roomserver/internal/perform_backfill.go +++ b/roomserver/internal/perform_backfill.go @@ -162,6 +162,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr // It returns a list of servers which can be queried for backfill requests. These servers // will be servers that are in the room already. The entries at the beginning are preferred servers // and will be tried first. An empty list will fail the request. +// nolint:gocyclo func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName { // eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use // its successor, so look it up. @@ -189,7 +190,17 @@ FindSuccessor: return nil } - stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID]) + info, err := b.db.RoomInfo(ctx, roomID) + if err != nil { + logrus.WithError(err).WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room") + return nil + } + if info == nil || info.IsStub { + logrus.WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room, room is missing") + return nil + } + + stateEntries, err := stateBeforeEvent(ctx, b.db, *info, NIDs[eventID]) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") return nil diff --git a/roomserver/internal/perform_invite.go b/roomserver/internal/perform_invite.go index 1cfbcc18..6690de05 100644 --- a/roomserver/internal/perform_invite.go +++ b/roomserver/internal/perform_invite.go @@ -208,7 +208,7 @@ func buildInviteStrippedState( StateKey: "", }) } - roomState := state.NewStateResolution(db) + roomState := state.NewStateResolution(db, *info) stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( ctx, info.StateSnapshotNID, stateWanted, ) diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go index 89716433..f8e8ba04 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query.go @@ -38,27 +38,22 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) - if err != nil { - response.RoomExists = false - return nil - } - - roomState := state.NewStateResolution(r.DB) - - info, err := r.DB.RoomInfo(ctx, request.RoomID) + roomInfo, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return err } - if info.IsStub { + if roomInfo == nil || roomInfo.IsStub { + response.RoomExists = false return nil } + + roomState := state.NewStateResolution(r.DB, *roomInfo) response.RoomExists = true - response.RoomVersion = roomVersion + response.RoomVersion = roomInfo.RoomVersion var currentStateSnapshotNID types.StateSnapshotNID response.LatestEvents, currentStateSnapshotNID, response.Depth, err = - r.DB.LatestEventIDs(ctx, info.RoomNID) + r.DB.LatestEventIDs(ctx, roomInfo.RoomNID) if err != nil { return err } @@ -85,7 +80,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState( } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) + response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion)) } return nil @@ -97,23 +92,17 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, ) error { - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) - if err != nil { - response.RoomExists = false - return nil - } - - roomState := state.NewStateResolution(r.DB) - info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return err } - if info.IsStub { + if info == nil || info.IsStub { return nil } + + roomState := state.NewStateResolution(r.DB, *info) response.RoomExists = true - response.RoomVersion = roomVersion + response.RoomVersion = info.RoomVersion prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs) if err != nil { @@ -128,7 +117,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( // Look up the currrent state for the requested tuples. stateEntries, err := roomState.LoadStateAfterEventsForStringTuples( - ctx, info.RoomNID, prevStates, request.StateToFetch, + ctx, prevStates, request.StateToFetch, ) if err != nil { return err @@ -140,7 +129,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, event.Headered(roomVersion)) + response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion)) } return nil @@ -168,7 +157,7 @@ func (r *RoomserverInternalAPI) QueryEventsByID( } for _, event := range events { - roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID()) + roomVersion, verr := r.roomVersion(event.RoomID()) if verr != nil { return verr } @@ -277,7 +266,7 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, eventNIDs) } else { - stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID) + stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err @@ -297,8 +286,8 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( return nil } -func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db) +func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { + roomState := state.NewStateResolution(db, info) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { @@ -370,20 +359,28 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see return } - isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, events[0].RoomID()) + roomID := events[0].RoomID() + isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID) if err != nil { return } + info, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return err + } + if info == nil { + return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) + } response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent( - ctx, request.EventID, request.ServerName, isServerInRoom, + ctx, *info, request.EventID, request.ServerName, isServerInRoom, ) return } func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( - ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, + ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { - roomState := state.NewStateResolution(r.DB) + roomState := state.NewStateResolution(r.DB, info) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { return false, err @@ -400,6 +397,7 @@ func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( } // QueryMissingEvents implements api.RoomserverInternalAPI +// nolint:gocyclo func (r *RoomserverInternalAPI) QueryMissingEvents( ctx context.Context, request *api.QueryMissingEventsRequest, @@ -418,8 +416,22 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( eventsToFilter[id] = true } } + events, err := r.DB.EventsFromIDs(ctx, front) + if err != nil { + return err + } + if len(events) == 0 { + return nil // we are missing the events being asked to search from, give up. + } + info, err := r.DB.RoomInfo(ctx, events[0].RoomID()) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) + } - resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) + resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -432,7 +444,7 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter)) for _, event := range loadedEvents { if !eventsToFilter[event.EventID()] { - roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID()) + roomVersion, verr := r.roomVersion(event.RoomID()) if verr != nil { return verr } @@ -467,8 +479,16 @@ func (r *RoomserverInternalAPI) PerformBackfill( // this will include these events which is what we want front = request.PrevEventIDs() + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil || info.IsStub { + return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID) + } + // Scan the event tree for events to send back. - resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName) + resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) if err != nil { return err } @@ -481,19 +501,14 @@ func (r *RoomserverInternalAPI) PerformBackfill( } for _, event := range loadedEvents { - roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID()) - if verr != nil { - return verr - } - - response.Events = append(response.Events, event.Headered(roomVersion)) + response.Events = append(response.Events, event.Headered(info.RoomVersion)) } return err } func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error { - roomVer, err := r.DB.GetRoomVersionForRoom(ctx, req.RoomID) + roomVer, err := r.roomVersion(req.RoomID) if err != nil { return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err) } @@ -642,7 +657,7 @@ func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, // TODO: Remove this when we have tests to assert correctness of this function // nolint:gocyclo func (r *RoomserverInternalAPI) scanEventTree( - ctx context.Context, front []string, visited map[string]bool, limit int, + ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int, serverName gomatrixserverlib.ServerName, ) ([]types.EventNID, error) { var resultNIDs []types.EventNID @@ -708,7 +723,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName, isServerInRoom) + allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", @@ -744,13 +759,13 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain( if err != nil { return err } - if info.IsStub { + if info == nil || info.IsStub { return nil } response.RoomExists = true response.RoomVersion = info.RoomVersion - stateEvents, err := r.loadStateAtEventIDs(ctx, request.PrevEventIDs) + stateEvents, err := r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs) if err != nil { return err } @@ -788,8 +803,8 @@ func (r *RoomserverInternalAPI) QueryStateAndAuthChain( return err } -func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { - roomState := state.NewStateResolution(r.DB) +func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) { + roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { @@ -937,15 +952,26 @@ func (r *RoomserverInternalAPI) QueryRoomVersionForRoom( return nil } - roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID) + info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return err } - response.RoomVersion = roomVersion + if info == nil { + return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID) + } + response.RoomVersion = info.RoomVersion r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion) return nil } +func (r *RoomserverInternalAPI) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) { + var res api.QueryRoomVersionForRoomResponse + err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{ + RoomID: roomID, + }, &res) + return res.RoomVersion, err +} + func (r *RoomserverInternalAPI) QueryPublishedRooms( ctx context.Context, req *api.QueryPublishedRoomsRequest, -- cgit v1.2.3