diff options
Diffstat (limited to 'roomserver/internal/query.go')
-rw-r--r-- | roomserver/internal/query.go | 466 |
1 files changed, 12 insertions, 454 deletions
diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go index f8e8ba04..26b22c74 100644 --- a/roomserver/internal/query.go +++ b/roomserver/internal/query.go @@ -20,11 +20,9 @@ import ( "context" "fmt" - "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/auth" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" - "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/version" "github.com/matrix-org/gomatrixserverlib" @@ -74,7 +72,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState( return err } - stateEvents, err := r.loadStateEvents(ctx, stateEntries) + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) if err != nil { return err } @@ -123,7 +121,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents( return err } - stateEvents, err := r.loadStateEvents(ctx, stateEntries) + stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries) if err != nil { return err } @@ -151,7 +149,7 @@ func (r *RoomserverInternalAPI) QueryEventsByID( eventNIDs = append(eventNIDs, nid) } - events, err := r.loadEvents(ctx, eventNIDs) + events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs) if err != nil { return err } @@ -168,31 +166,6 @@ func (r *RoomserverInternalAPI) QueryEventsByID( return nil } -func (r *RoomserverInternalAPI) loadStateEvents( - ctx context.Context, stateEntries []types.StateEntry, -) ([]gomatrixserverlib.Event, error) { - eventNIDs := make([]types.EventNID, len(stateEntries)) - for i := range stateEntries { - eventNIDs[i] = stateEntries[i].EventNID - } - return r.loadEvents(ctx, eventNIDs) -} - -func (r *RoomserverInternalAPI) loadEvents( - ctx context.Context, eventNIDs []types.EventNID, -) ([]gomatrixserverlib.Event, error) { - stateEvents, err := r.DB.Events(ctx, eventNIDs) - if err != nil { - return nil, err - } - - result := make([]gomatrixserverlib.Event, len(stateEvents)) - for i := range stateEvents { - result[i] = stateEvents[i].Event - } - return result, nil -} - // QueryMembershipForUser implements api.RoomserverInternalAPI func (r *RoomserverInternalAPI) QueryMembershipForUser( ctx context.Context, @@ -266,12 +239,12 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, eventNIDs) } else { - stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID) + stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err } - events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) + events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) } if err != nil { @@ -286,65 +259,6 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom( return nil } -func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) - // Lookup the event NID - eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) - if err != nil { - return nil, err - } - eventIDs := []string{eIDs[eventNID]} - - prevState, err := db.StateAtEventIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - // Fetch the state as it was when this event was fired - return roomState.LoadCombinedStateAfterEvents(ctx, prevState) -} - -// getMembershipsAtState filters the state events to -// only keep the "m.room.member" events with a "join" membership. These events are returned. -// Returns an error if there was an issue fetching the events. -func getMembershipsAtState( - ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, -) ([]types.Event, error) { - - var eventNIDs []types.EventNID - for _, entry := range stateEntries { - // Filter the events to retrieve to only keep the membership events - if entry.EventTypeNID == types.MRoomMemberNID { - eventNIDs = append(eventNIDs, entry.EventNID) - } - } - - // Get all of the events in this state - stateEvents, err := db.Events(ctx, eventNIDs) - if err != nil { - return nil, err - } - - if !joinedOnly { - return stateEvents, nil - } - - // Filter the events to only keep the "join" membership events - var events []types.Event - for _, event := range stateEvents { - membership, err := event.Membership() - if err != nil { - return nil, err - } - - if membership == gomatrixserverlib.Join { - events = append(events, event) - } - } - - return events, nil -} - // QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( ctx context.Context, @@ -360,7 +274,7 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( return } roomID := events[0].RoomID() - isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID) + isServerInRoom, err := helpers.IsServerCurrentlyInRoom(ctx, r.DB, request.ServerName, roomID) if err != nil { return } @@ -371,31 +285,12 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent( if info == nil { return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) } - response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent( - ctx, *info, request.EventID, request.ServerName, isServerInRoom, + response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( + ctx, r.DB, *info, request.EventID, request.ServerName, isServerInRoom, ) return } -func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent( - ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, -) (bool, error) { - roomState := state.NewStateResolution(r.DB, info) - stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) - if err != nil { - return false, err - } - - // TODO: We probably want to make it so that we don't have to pull - // out all the state if possible. - stateAtEvent, err := r.loadStateEvents(ctx, stateEntries) - if err != nil { - return false, err - } - - return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil -} - // QueryMissingEvents implements api.RoomserverInternalAPI // nolint:gocyclo func (r *RoomserverInternalAPI) QueryMissingEvents( @@ -431,12 +326,12 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID()) } - resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) + resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName) if err != nil { return err } - loadedEvents, err := r.loadEvents(ctx, resultNIDs) + loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs) if err != nil { return err } @@ -456,299 +351,6 @@ func (r *RoomserverInternalAPI) QueryMissingEvents( return err } -// PerformBackfill implements api.RoomServerQueryAPI -func (r *RoomserverInternalAPI) PerformBackfill( - ctx context.Context, - request *api.PerformBackfillRequest, - response *api.PerformBackfillResponse, -) error { - // if we are requesting the backfill then we need to do a federation hit - // TODO: we could be more sensible and fetch as many events we already have then request the rest - // which is what the syncapi does already. - if request.ServerName == r.ServerName { - return r.backfillViaFederation(ctx, request, response) - } - // someone else is requesting the backfill, try to service their request. - var err error - var front []string - - // The limit defines the maximum number of events to retrieve, so it also - // defines the highest number of elements in the map below. - visited := make(map[string]bool, request.Limit) - - // this will include these events which is what we want - front = request.PrevEventIDs() - - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info == nil || info.IsStub { - return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID) - } - - // Scan the event tree for events to send back. - resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName) - if err != nil { - return err - } - - // Retrieve events from the list that was filled previously. - var loadedEvents []gomatrixserverlib.Event - loadedEvents, err = r.loadEvents(ctx, resultNIDs) - if err != nil { - return err - } - - for _, event := range loadedEvents { - response.Events = append(response.Events, event.Headered(info.RoomVersion)) - } - - return err -} - -func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error { - roomVer, err := r.roomVersion(req.RoomID) - if err != nil { - return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err) - } - requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities) - // Request 100 items regardless of what the query asks for. - // We don't want to go much higher than this. - // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass - // (so we don't need to hit /state_ids which the test has no listener for) - // Specifically the test "Outbound federation can backfill events" - events, err := gomatrixserverlib.RequestBackfill( - ctx, requester, - r.KeyRing, req.RoomID, roomVer, req.PrevEventIDs(), 100) - if err != nil { - return err - } - logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) - - // persist these new events - auth checks have already been done - roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) - if err != nil { - return err - } - - for _, ev := range backfilledEventMap { - // now add state for these events - stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()] - if !ok { - // this should be impossible as all events returned must have pass Step 5 of the PDU checks - // which requires a list of state IDs. - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks") - continue - } - var entries []types.StateEntry - if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil { - // attempt to fetch the missing events - r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs) - // try again - entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs) - if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event") - return err - } - } - - var beforeStateSnapshotNID types.StateSnapshotNID - if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") - return err - } - if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid") - } - } - - // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point. - - res.Events = events - return nil -} - -func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) { - info, err := r.DB.RoomInfo(ctx, roomID) - if err != nil { - return false, err - } - if info == nil { - return false, fmt.Errorf("unknown room %s", roomID) - } - - eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) - if err != nil { - return false, err - } - - events, err := r.DB.Events(ctx, eventNIDs) - if err != nil { - return false, err - } - gmslEvents := make([]gomatrixserverlib.Event, len(events)) - for i := range events { - gmslEvents[i] = events[i].Event - } - return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil -} - -// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just -// best effort. -func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - backfillRequester *backfillRequester, stateIDs []string) { - - servers := backfillRequester.servers - - // work out which are missing - nidMap, err := r.DB.EventNIDs(ctx, stateIDs) - if err != nil { - util.GetLogger(ctx).WithError(err).Warn("cannot query missing events") - return - } - missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event - for _, id := range stateIDs { - if _, ok := nidMap[id]; !ok { - missingMap[id] = nil - } - } - util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers)) - - // fetch the events from federation. Loop the servers first so if we find one that works we stick with them - for _, srv := range servers { - for id, ev := range missingMap { - if ev != nil { - continue // already found - } - logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) - res, err := r.FedClient.GetEvent(ctx, srv, id) - if err != nil { - logger.WithError(err).Warn("failed to get event from server") - continue - } - loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) - if err != nil { - logger.WithError(err).Warn("failed to load and verify event") - continue - } - logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result) - for _, res := range result { - if res.Error != nil { - logger.WithError(err).Warn("event failed PDU checks") - continue - } - missingMap[id] = res.Event - } - } - } - - var newEvents []gomatrixserverlib.HeaderedEvent - for _, ev := range missingMap { - if ev != nil { - newEvents = append(newEvents, *ev) - } - } - util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) - persistEvents(ctx, r.DB, newEvents) -} - -// TODO: Remove this when we have tests to assert correctness of this function -// nolint:gocyclo -func (r *RoomserverInternalAPI) scanEventTree( - ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int, - serverName gomatrixserverlib.ServerName, -) ([]types.EventNID, error) { - var resultNIDs []types.EventNID - var err error - var allowed bool - var events []types.Event - var next []string - var pre string - - // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be) - // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing - // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in - // duplicate events being sent in response to /backfill requests. - initialIgnoreList := make(map[string]bool, len(visited)) - for k, v := range visited { - initialIgnoreList[k] = v - } - - resultNIDs = make([]types.EventNID, 0, limit) - - var checkedServerInRoom bool - var isServerInRoom bool - - // Loop through the event IDs to retrieve the requested events and go - // through the whole tree (up to the provided limit) using the events' - // "prev_event" key. -BFSLoop: - for len(front) > 0 { - // Prevent unnecessary allocations: reset the slice only when not empty. - if len(next) > 0 { - next = make([]string, 0) - } - // Retrieve the events to process from the database. - events, err = r.DB.EventsFromIDs(ctx, front) - if err != nil { - return resultNIDs, err - } - - if !checkedServerInRoom && len(events) > 0 { - // It's nasty that we have to extract the room ID from an event, but many federation requests - // only talk in event IDs, no room IDs at all (!!!) - ev := events[0] - isServerInRoom, err = r.isServerCurrentlyInRoom(ctx, serverName, ev.RoomID()) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.") - } - checkedServerInRoom = true - } - - for _, ev := range events { - // Break out of the loop if the provided limit is reached. - if len(resultNIDs) == limit { - break BFSLoop - } - - if !initialIgnoreList[ev.EventID()] { - // Update the list of events to retrieve. - resultNIDs = append(resultNIDs, ev.EventNID) - } - // Loop through the event's parents. - for _, pre = range ev.PrevEventIDs() { - // Only add an event to the list of next events to process if it - // hasn't been seen before. - if !visited[pre] { - visited[pre] = true - allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom) - if err != nil { - util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( - "Error checking if allowed to see event", - ) - return resultNIDs, err - } - - // If the event hasn't been seen before and the HS - // requesting to retrieve it is allowed to do so, add it to - // the list of events to retrieve. - if allowed { - next = append(next, pre) - } else { - util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event") - } - } - } - } - // Repeat the same process with the parent events we just processed. - front = next - } - - return resultNIDs, err -} - // QueryStateAndAuthChain implements api.RoomserverInternalAPI func (r *RoomserverInternalAPI) QueryStateAndAuthChain( ctx context.Context, @@ -823,7 +425,7 @@ func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInf return nil, err } - return r.loadStateEvents(ctx, stateEntries) + return helpers.LoadStateEvents(ctx, r.DB, stateEntries) } type eventsFromIDs func(context.Context, []string) ([]types.Event, error) @@ -879,50 +481,6 @@ func getAuthChain( return authEvents, nil } -func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) { - var roomNID types.RoomNID - backfilledEventMap := make(map[string]types.Event) - for j, ev := range events { - nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs()) - if err != nil { // this shouldn't happen as RequestBackfill already found them - logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events") - continue - } - authNids := make([]types.EventNID, len(nidMap)) - i := 0 - for _, nid := range nidMap { - authNids[i] = nid - i++ - } - var stateAtEvent types.StateAtEvent - var redactedEventID string - var redactionEvent *gomatrixserverlib.Event - roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) - if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") - continue - } - // If storing this event results in it being redacted, then do so. - // It's also possible for this event to be a redaction which results in another event being - // redacted, which we don't care about since we aren't returning it in this backfill. - if redactedEventID == ev.EventID() { - eventToRedact := ev.Unwrap() - redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact) - if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event") - continue - } - ev = redactedEvent.Headered(ev.RoomVersion) - events[j] = ev - } - backfilledEventMap[ev.EventID()] = types.Event{ - EventNID: stateAtEvent.StateEntry.EventNID, - Event: ev.Unwrap(), - } - } - return roomNID, backfilledEventMap -} - // QueryRoomVersionCapabilities implements api.RoomserverInternalAPI func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities( ctx context.Context, |