diff options
author | Kegsay <kegan@matrix.org> | 2020-04-29 18:41:45 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-29 18:41:45 +0100 |
commit | 4ad52c67cacc21997f49decd57e3105beb8ab62d (patch) | |
tree | e6606ee4fa49b9e62940c04a287c68e81c2b0749 /roomserver | |
parent | 458b3647815f0f2c6930611961431a9fb4390fba (diff) |
Honour history_visibility when backfilling (#990)
* Make backfill work for shared history visibility
* fetch missing state on backfill to remember snapshots correctly
* Fix gmsl to not mux in auth events into room state
* Whoops
* Linting
Diffstat (limited to 'roomserver')
-rw-r--r-- | roomserver/auth/auth.go | 4 | ||||
-rw-r--r-- | roomserver/query/backfill.go | 81 | ||||
-rw-r--r-- | roomserver/query/query.go | 108 | ||||
-rw-r--r-- | roomserver/state/state.go | 5 | ||||
-rw-r--r-- | roomserver/storage/sqlite3/events_table.go | 28 | ||||
-rw-r--r-- | roomserver/storage/sqlite3/storage.go | 2 |
6 files changed, 181 insertions, 47 deletions
diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go index 615a94b3..fdcf9f06 100644 --- a/roomserver/auth/auth.go +++ b/roomserver/auth/auth.go @@ -27,7 +27,7 @@ func IsServerAllowed( serverCurrentlyInRoom bool, authEvents []gomatrixserverlib.Event, ) bool { - historyVisibility := historyVisibilityForRoom(authEvents) + historyVisibility := HistoryVisibilityForRoom(authEvents) // 1. If the history_visibility was set to world_readable, allow. if historyVisibility == "world_readable" { @@ -52,7 +52,7 @@ func IsServerAllowed( return false } -func historyVisibilityForRoom(authEvents []gomatrixserverlib.Event) string { +func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.Event) string { // https://matrix.org/docs/spec/client_server/r0.6.0#id87 // By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared. visibility := "shared" diff --git a/roomserver/query/backfill.go b/roomserver/query/backfill.go index 09a515e9..f518de3e 100644 --- a/roomserver/query/backfill.go +++ b/roomserver/query/backfill.go @@ -3,6 +3,7 @@ package query import ( "context" + "github.com/matrix-org/dendrite/roomserver/auth" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -62,9 +63,9 @@ FederationHit: logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event") for _, srv := range b.servers { // hit any valid server c := gomatrixserverlib.FederatedStateProvider{ - FedClient: b.fedClient, - AuthEventsOnly: false, - Server: srv, + FedClient: b.fedClient, + RememberAuthEvents: false, + Server: srv, } res, err := c.StateIDsBeforeEvent(ctx, targetEvent) if err != nil { @@ -114,7 +115,9 @@ func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrix return nil } -func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { +func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, + event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) { + // try to fetch the events from the database first events, err := b.ProvideEvents(roomVer, eventIDs) if err != nil { @@ -133,9 +136,9 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr } c := gomatrixserverlib.FederatedStateProvider{ - FedClient: b.fedClient, - AuthEventsOnly: false, - Server: b.servers[0], + FedClient: b.fedClient, + RememberAuthEvents: false, + Server: b.servers[0], } result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs) if err != nil { @@ -160,18 +163,33 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID return } + stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID]) + if err != nil { + logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event") + return + } + + // possibly return all joined servers depending on history visiblity + memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries) + if err != nil { + logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") + return + } + logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis)) + // Retrieve all "m.room.member" state events of "join" membership, which // contains the list of users in the room before the event, therefore all // the servers in it at that moment. - events, err := getMembershipsBeforeEventNID(ctx, b.db, NIDs[eventID], true) + memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true) if err != nil { logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event") return } + memberEvents = append(memberEvents, memberEventsFromVis...) // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[gomatrixserverlib.ServerName]bool) - for _, event := range events { + for _, event := range memberEvents { serverSet[event.Origin()] = true } for server := range serverSet { @@ -186,7 +204,9 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID // Backfill performs a backfill request to the given server. // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid -func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) { +func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, + fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) { + tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs) return &tx, err } @@ -215,3 +235,44 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, } return events, nil } + +// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility. +// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just +// pull all events and then filter by that table. +func joinEventsFromHistoryVisibility( + ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) { + + var eventNIDs []types.EventNID + for _, entry := range stateEntries { + // Filter the events to retrieve to only keep the membership events + if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID { + eventNIDs = append(eventNIDs, entry.EventNID) + break + } + } + + // Get all of the events in this state + stateEvents, err := db.Events(ctx, eventNIDs) + if err != nil { + return nil, err + } + events := make([]gomatrixserverlib.Event, len(stateEvents)) + for i := range stateEvents { + events[i] = stateEvents[i].Event + } + visibility := auth.HistoryVisibilityForRoom(events) + if visibility != "shared" { + logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility) + return nil, nil + } + // get joined members + roomNID, err := db.RoomNID(ctx, roomID) + if err != nil { + return nil, err + } + joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true) + if err != nil { + return nil, err + } + return db.Events(ctx, joinEventNIDs) +} diff --git a/roomserver/query/query.go b/roomserver/query/query.go index a54fa58d..6778ac28 100644 --- a/roomserver/query/query.go +++ b/roomserver/query/query.go @@ -277,6 +277,7 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( response.JoinEvents = []gomatrixserverlib.ClientEvent{} var events []types.Event + var stateEntries []types.StateEntry if stillInRoom { var eventNIDs []types.EventNID eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly) @@ -286,7 +287,12 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, eventNIDs) } else { - events, err = getMembershipsBeforeEventNID(ctx, r.DB, membershipEventNID, request.JoinedOnly) + stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID) + if err != nil { + logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") + return err + } + events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly) } if err != nil { @@ -301,15 +307,8 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom( return nil } -// getMembershipsBeforeEventNID takes the numeric ID of an event and fetches the state -// of the event's room as it was when this event was fired, then filters the state events to -// only keep the "m.room.member" events with a "join" membership. These events are returned. -// Returns an error if there was an issue fetching the events. -func getMembershipsBeforeEventNID( - ctx context.Context, db storage.Database, eventNID types.EventNID, joinedOnly bool, -) ([]types.Event, error) { +func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) { roomState := state.NewStateResolution(db) - events := []types.Event{} // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { @@ -323,10 +322,15 @@ func getMembershipsBeforeEventNID( } // Fetch the state as it was when this event was fired - stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, prevState) - if err != nil { - return nil, err - } + return roomState.LoadCombinedStateAfterEvents(ctx, prevState) +} + +// getMembershipsAtState filters the state events to +// only keep the "m.room.member" events with a "join" membership. These events are returned. +// Returns an error if there was an issue fetching the events. +func getMembershipsAtState( + ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool, +) ([]types.Event, error) { var eventNIDs []types.EventNID for _, entry := range stateEntries { @@ -347,6 +351,7 @@ func getMembershipsBeforeEventNID( } // Filter the events to only keep the "join" membership events + var events []types.Event for _, event := range stateEvents { membership, err := event.Membership() if err != nil { @@ -563,20 +568,29 @@ func (r *RoomserverQueryAPI) backfillViaFederation(ctx context.Context, req *api if !ok { // this should be impossible as all events returned must have pass Step 5 of the PDU checks // which requires a list of state IDs. - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to find state IDs for event which passed auth checks") + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks") continue } var entries []types.StateEntry if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil { - return err + // attempt to fetch the missing events + r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs) + // try again + entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs) + if err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event") + return err + } } var beforeStateSnapshotNID types.StateSnapshotNID if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") return err } + util.GetLogger(ctx).Infof("Backfilled event %s (nid=%d) getting snapshot %v with entries %+v", ev.EventID(), ev.EventNID, beforeStateSnapshotNID, entries) if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to set state before event") + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid") } } @@ -608,6 +622,66 @@ func (r *RoomserverQueryAPI) isServerCurrentlyInRoom(ctx context.Context, server return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil } +// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just +// best effort. +func (r *RoomserverQueryAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, + backfillRequester *backfillRequester, stateIDs []string) { + + servers := backfillRequester.servers + + // work out which are missing + nidMap, err := r.DB.EventNIDs(ctx, stateIDs) + if err != nil { + util.GetLogger(ctx).WithError(err).Warn("cannot query missing events") + return + } + missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event + for _, id := range stateIDs { + if _, ok := nidMap[id]; !ok { + missingMap[id] = nil + } + } + util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers)) + + // fetch the events from federation. Loop the servers first so if we find one that works we stick with them + for _, srv := range servers { + for id, ev := range missingMap { + if ev != nil { + continue // already found + } + logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) + res, err := r.FedClient.GetEvent(ctx, srv, id) + if err != nil { + logger.WithError(err).Warn("failed to get event from server") + continue + } + loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) + if err != nil { + logger.WithError(err).Warn("failed to load and verify event") + continue + } + logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result) + for _, res := range result { + if res.Error != nil { + logger.WithError(err).Warn("event failed PDU checks") + continue + } + missingMap[id] = res.Event + } + } + } + + var newEvents []gomatrixserverlib.HeaderedEvent + for _, ev := range missingMap { + if ev != nil { + newEvents = append(newEvents, *ev) + } + } + util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents)) + persistEvents(ctx, r.DB, newEvents) +} + // TODO: Remove this when we have tests to assert correctness of this function // nolint:gocyclo func (r *RoomserverQueryAPI) scanEventTree( @@ -857,7 +931,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse var stateAtEvent types.StateAtEvent roomNID, stateAtEvent, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) if err != nil { - logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to store backfilled event") + logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue } backfilledEventMap[ev.EventID()] = types.Event{ diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 389c9440..9b005ee6 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -86,7 +86,10 @@ func (v StateResolution) LoadStateAtEvent( ) ([]types.StateEntry, error) { snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID) if err != nil { - return nil, err + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err) + } + if snapshotNID == 0 { + return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) } stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID) diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index d881fa91..a63596ae 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -48,11 +48,6 @@ const insertEventSQL = ` ON CONFLICT DO NOTHING; ` -const insertEventResultSQL = ` - SELECT event_nid, state_snapshot_nid FROM roomserver_events - WHERE rowid = last_insert_rowid(); -` - const selectEventSQL = "" + "SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1" @@ -102,7 +97,6 @@ const selectRoomNIDForEventNIDSQL = "" + type eventStatements struct { db *sql.DB insertEventStmt *sql.Stmt - insertEventResultStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt @@ -126,7 +120,6 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { return statementList{ {&s.insertEventStmt, insertEventSQL}, - {&s.insertEventResultStmt, insertEventResultSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, @@ -152,19 +145,22 @@ func (s *eventStatements) insertEvent( referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, -) (types.EventNID, types.StateSnapshotNID, error) { - var eventNID int64 - var stateNID int64 - var err error +) (types.EventNID, error) { + // attempt to insert: the last_row_id is the event NID insertStmt := common.TxStmt(txn, s.insertEventStmt) - resultStmt := common.TxStmt(txn, s.insertEventResultStmt) - if _, err = insertStmt.ExecContext( + result, err := insertStmt.ExecContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, - ); err == nil { - err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID) + ) + if err != nil { + return 0, err } - return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err + modified, err := result.RowsAffected() + if modified == 0 && err == nil { + return 0, sql.ErrNoRows + } + eventNID, err := result.LastInsertId() + return types.EventNID(eventNID), err } func (s *eventStatements) selectEvent( diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 5df9c4e0..b6e846df 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -124,7 +124,7 @@ func (d *Database) StoreEvent( } } - if eventNID, stateNID, err = d.statements.insertEvent( + if eventNID, err = d.statements.insertEvent( ctx, txn, roomNID, |