aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/query.go')
-rw-r--r--roomserver/internal/query.go466
1 files changed, 12 insertions, 454 deletions
diff --git a/roomserver/internal/query.go b/roomserver/internal/query.go
index f8e8ba04..26b22c74 100644
--- a/roomserver/internal/query.go
+++ b/roomserver/internal/query.go
@@ -20,11 +20,9 @@ import (
"context"
"fmt"
- "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
- "github.com/matrix-org/dendrite/roomserver/auth"
+ "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state"
- "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrixserverlib"
@@ -74,7 +72,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState(
return err
}
- stateEvents, err := r.loadStateEvents(ctx, stateEntries)
+ stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
if err != nil {
return err
}
@@ -123,7 +121,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents(
return err
}
- stateEvents, err := r.loadStateEvents(ctx, stateEntries)
+ stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
if err != nil {
return err
}
@@ -151,7 +149,7 @@ func (r *RoomserverInternalAPI) QueryEventsByID(
eventNIDs = append(eventNIDs, nid)
}
- events, err := r.loadEvents(ctx, eventNIDs)
+ events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs)
if err != nil {
return err
}
@@ -168,31 +166,6 @@ func (r *RoomserverInternalAPI) QueryEventsByID(
return nil
}
-func (r *RoomserverInternalAPI) loadStateEvents(
- ctx context.Context, stateEntries []types.StateEntry,
-) ([]gomatrixserverlib.Event, error) {
- eventNIDs := make([]types.EventNID, len(stateEntries))
- for i := range stateEntries {
- eventNIDs[i] = stateEntries[i].EventNID
- }
- return r.loadEvents(ctx, eventNIDs)
-}
-
-func (r *RoomserverInternalAPI) loadEvents(
- ctx context.Context, eventNIDs []types.EventNID,
-) ([]gomatrixserverlib.Event, error) {
- stateEvents, err := r.DB.Events(ctx, eventNIDs)
- if err != nil {
- return nil, err
- }
-
- result := make([]gomatrixserverlib.Event, len(stateEvents))
- for i := range stateEvents {
- result[i] = stateEvents[i].Event
- }
- return result, nil
-}
-
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryMembershipForUser(
ctx context.Context,
@@ -266,12 +239,12 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, eventNIDs)
} else {
- stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID)
+ stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
}
- events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
+ events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
}
if err != nil {
@@ -286,65 +259,6 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
return nil
}
-func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
- roomState := state.NewStateResolution(db, info)
- // Lookup the event NID
- eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
- if err != nil {
- return nil, err
- }
- eventIDs := []string{eIDs[eventNID]}
-
- prevState, err := db.StateAtEventIDs(ctx, eventIDs)
- if err != nil {
- return nil, err
- }
-
- // Fetch the state as it was when this event was fired
- return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
-}
-
-// getMembershipsAtState filters the state events to
-// only keep the "m.room.member" events with a "join" membership. These events are returned.
-// Returns an error if there was an issue fetching the events.
-func getMembershipsAtState(
- ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
-) ([]types.Event, error) {
-
- var eventNIDs []types.EventNID
- for _, entry := range stateEntries {
- // Filter the events to retrieve to only keep the membership events
- if entry.EventTypeNID == types.MRoomMemberNID {
- eventNIDs = append(eventNIDs, entry.EventNID)
- }
- }
-
- // Get all of the events in this state
- stateEvents, err := db.Events(ctx, eventNIDs)
- if err != nil {
- return nil, err
- }
-
- if !joinedOnly {
- return stateEvents, nil
- }
-
- // Filter the events to only keep the "join" membership events
- var events []types.Event
- for _, event := range stateEvents {
- membership, err := event.Membership()
- if err != nil {
- return nil, err
- }
-
- if membership == gomatrixserverlib.Join {
- events = append(events, event)
- }
- }
-
- return events, nil
-}
-
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
ctx context.Context,
@@ -360,7 +274,7 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
return
}
roomID := events[0].RoomID()
- isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID)
+ isServerInRoom, err := helpers.IsServerCurrentlyInRoom(ctx, r.DB, request.ServerName, roomID)
if err != nil {
return
}
@@ -371,31 +285,12 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
if info == nil {
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
}
- response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent(
- ctx, *info, request.EventID, request.ServerName, isServerInRoom,
+ response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
+ ctx, r.DB, *info, request.EventID, request.ServerName, isServerInRoom,
)
return
}
-func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent(
- ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
-) (bool, error) {
- roomState := state.NewStateResolution(r.DB, info)
- stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
- if err != nil {
- return false, err
- }
-
- // TODO: We probably want to make it so that we don't have to pull
- // out all the state if possible.
- stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
- if err != nil {
- return false, err
- }
-
- return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
-}
-
// QueryMissingEvents implements api.RoomserverInternalAPI
// nolint:gocyclo
func (r *RoomserverInternalAPI) QueryMissingEvents(
@@ -431,12 +326,12 @@ func (r *RoomserverInternalAPI) QueryMissingEvents(
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
}
- resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
+ resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName)
if err != nil {
return err
}
- loadedEvents, err := r.loadEvents(ctx, resultNIDs)
+ loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs)
if err != nil {
return err
}
@@ -456,299 +351,6 @@ func (r *RoomserverInternalAPI) QueryMissingEvents(
return err
}
-// PerformBackfill implements api.RoomServerQueryAPI
-func (r *RoomserverInternalAPI) PerformBackfill(
- ctx context.Context,
- request *api.PerformBackfillRequest,
- response *api.PerformBackfillResponse,
-) error {
- // if we are requesting the backfill then we need to do a federation hit
- // TODO: we could be more sensible and fetch as many events we already have then request the rest
- // which is what the syncapi does already.
- if request.ServerName == r.ServerName {
- return r.backfillViaFederation(ctx, request, response)
- }
- // someone else is requesting the backfill, try to service their request.
- var err error
- var front []string
-
- // The limit defines the maximum number of events to retrieve, so it also
- // defines the highest number of elements in the map below.
- visited := make(map[string]bool, request.Limit)
-
- // this will include these events which is what we want
- front = request.PrevEventIDs()
-
- info, err := r.DB.RoomInfo(ctx, request.RoomID)
- if err != nil {
- return err
- }
- if info == nil || info.IsStub {
- return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID)
- }
-
- // Scan the event tree for events to send back.
- resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
- if err != nil {
- return err
- }
-
- // Retrieve events from the list that was filled previously.
- var loadedEvents []gomatrixserverlib.Event
- loadedEvents, err = r.loadEvents(ctx, resultNIDs)
- if err != nil {
- return err
- }
-
- for _, event := range loadedEvents {
- response.Events = append(response.Events, event.Headered(info.RoomVersion))
- }
-
- return err
-}
-
-func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error {
- roomVer, err := r.roomVersion(req.RoomID)
- if err != nil {
- return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err)
- }
- requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities)
- // Request 100 items regardless of what the query asks for.
- // We don't want to go much higher than this.
- // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
- // (so we don't need to hit /state_ids which the test has no listener for)
- // Specifically the test "Outbound federation can backfill events"
- events, err := gomatrixserverlib.RequestBackfill(
- ctx, requester,
- r.KeyRing, req.RoomID, roomVer, req.PrevEventIDs(), 100)
- if err != nil {
- return err
- }
- logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
-
- // persist these new events - auth checks have already been done
- roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
- if err != nil {
- return err
- }
-
- for _, ev := range backfilledEventMap {
- // now add state for these events
- stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()]
- if !ok {
- // this should be impossible as all events returned must have pass Step 5 of the PDU checks
- // which requires a list of state IDs.
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
- continue
- }
- var entries []types.StateEntry
- if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
- // attempt to fetch the missing events
- r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs)
- // try again
- entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
- if err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
- return err
- }
- }
-
- var beforeStateSnapshotNID types.StateSnapshotNID
- if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
- return err
- }
- if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
- }
- }
-
- // TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
-
- res.Events = events
- return nil
-}
-
-func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
- info, err := r.DB.RoomInfo(ctx, roomID)
- if err != nil {
- return false, err
- }
- if info == nil {
- return false, fmt.Errorf("unknown room %s", roomID)
- }
-
- eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
- if err != nil {
- return false, err
- }
-
- events, err := r.DB.Events(ctx, eventNIDs)
- if err != nil {
- return false, err
- }
- gmslEvents := make([]gomatrixserverlib.Event, len(events))
- for i := range events {
- gmslEvents[i] = events[i].Event
- }
- return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
-}
-
-// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
-// best effort.
-func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
- backfillRequester *backfillRequester, stateIDs []string) {
-
- servers := backfillRequester.servers
-
- // work out which are missing
- nidMap, err := r.DB.EventNIDs(ctx, stateIDs)
- if err != nil {
- util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
- return
- }
- missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event
- for _, id := range stateIDs {
- if _, ok := nidMap[id]; !ok {
- missingMap[id] = nil
- }
- }
- util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers))
-
- // fetch the events from federation. Loop the servers first so if we find one that works we stick with them
- for _, srv := range servers {
- for id, ev := range missingMap {
- if ev != nil {
- continue // already found
- }
- logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
- res, err := r.FedClient.GetEvent(ctx, srv, id)
- if err != nil {
- logger.WithError(err).Warn("failed to get event from server")
- continue
- }
- loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
- result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
- if err != nil {
- logger.WithError(err).Warn("failed to load and verify event")
- continue
- }
- logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result)
- for _, res := range result {
- if res.Error != nil {
- logger.WithError(err).Warn("event failed PDU checks")
- continue
- }
- missingMap[id] = res.Event
- }
- }
- }
-
- var newEvents []gomatrixserverlib.HeaderedEvent
- for _, ev := range missingMap {
- if ev != nil {
- newEvents = append(newEvents, *ev)
- }
- }
- util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
- persistEvents(ctx, r.DB, newEvents)
-}
-
-// TODO: Remove this when we have tests to assert correctness of this function
-// nolint:gocyclo
-func (r *RoomserverInternalAPI) scanEventTree(
- ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int,
- serverName gomatrixserverlib.ServerName,
-) ([]types.EventNID, error) {
- var resultNIDs []types.EventNID
- var err error
- var allowed bool
- var events []types.Event
- var next []string
- var pre string
-
- // TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
- // Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
- // so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
- // duplicate events being sent in response to /backfill requests.
- initialIgnoreList := make(map[string]bool, len(visited))
- for k, v := range visited {
- initialIgnoreList[k] = v
- }
-
- resultNIDs = make([]types.EventNID, 0, limit)
-
- var checkedServerInRoom bool
- var isServerInRoom bool
-
- // Loop through the event IDs to retrieve the requested events and go
- // through the whole tree (up to the provided limit) using the events'
- // "prev_event" key.
-BFSLoop:
- for len(front) > 0 {
- // Prevent unnecessary allocations: reset the slice only when not empty.
- if len(next) > 0 {
- next = make([]string, 0)
- }
- // Retrieve the events to process from the database.
- events, err = r.DB.EventsFromIDs(ctx, front)
- if err != nil {
- return resultNIDs, err
- }
-
- if !checkedServerInRoom && len(events) > 0 {
- // It's nasty that we have to extract the room ID from an event, but many federation requests
- // only talk in event IDs, no room IDs at all (!!!)
- ev := events[0]
- isServerInRoom, err = r.isServerCurrentlyInRoom(ctx, serverName, ev.RoomID())
- if err != nil {
- util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
- }
- checkedServerInRoom = true
- }
-
- for _, ev := range events {
- // Break out of the loop if the provided limit is reached.
- if len(resultNIDs) == limit {
- break BFSLoop
- }
-
- if !initialIgnoreList[ev.EventID()] {
- // Update the list of events to retrieve.
- resultNIDs = append(resultNIDs, ev.EventNID)
- }
- // Loop through the event's parents.
- for _, pre = range ev.PrevEventIDs() {
- // Only add an event to the list of next events to process if it
- // hasn't been seen before.
- if !visited[pre] {
- visited[pre] = true
- allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom)
- if err != nil {
- util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
- "Error checking if allowed to see event",
- )
- return resultNIDs, err
- }
-
- // If the event hasn't been seen before and the HS
- // requesting to retrieve it is allowed to do so, add it to
- // the list of events to retrieve.
- if allowed {
- next = append(next, pre)
- } else {
- util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
- }
- }
- }
- }
- // Repeat the same process with the parent events we just processed.
- front = next
- }
-
- return resultNIDs, err
-}
-
// QueryStateAndAuthChain implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryStateAndAuthChain(
ctx context.Context,
@@ -823,7 +425,7 @@ func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInf
return nil, err
}
- return r.loadStateEvents(ctx, stateEntries)
+ return helpers.LoadStateEvents(ctx, r.DB, stateEntries)
}
type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
@@ -879,50 +481,6 @@ func getAuthChain(
return authEvents, nil
}
-func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
- var roomNID types.RoomNID
- backfilledEventMap := make(map[string]types.Event)
- for j, ev := range events {
- nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs())
- if err != nil { // this shouldn't happen as RequestBackfill already found them
- logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events")
- continue
- }
- authNids := make([]types.EventNID, len(nidMap))
- i := 0
- for _, nid := range nidMap {
- authNids[i] = nid
- i++
- }
- var stateAtEvent types.StateAtEvent
- var redactedEventID string
- var redactionEvent *gomatrixserverlib.Event
- roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
- if err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
- continue
- }
- // If storing this event results in it being redacted, then do so.
- // It's also possible for this event to be a redaction which results in another event being
- // redacted, which we don't care about since we aren't returning it in this backfill.
- if redactedEventID == ev.EventID() {
- eventToRedact := ev.Unwrap()
- redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact)
- if err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
- continue
- }
- ev = redactedEvent.Headered(ev.RoomVersion)
- events[j] = ev
- }
- backfilledEventMap[ev.EventID()] = types.Event{
- EventNID: stateAtEvent.StateEntry.EventNID,
- Event: ev.Unwrap(),
- }
- }
- return roomNID, backfilledEventMap
-}
-
// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities(
ctx context.Context,