aboutsummaryrefslogtreecommitdiff
path: root/roomserver
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-04-29 18:41:45 +0100
committerGitHub <noreply@github.com>2020-04-29 18:41:45 +0100
commit4ad52c67cacc21997f49decd57e3105beb8ab62d (patch)
treee6606ee4fa49b9e62940c04a287c68e81c2b0749 /roomserver
parent458b3647815f0f2c6930611961431a9fb4390fba (diff)
Honour history_visibility when backfilling (#990)
* Make backfill work for shared history visibility * fetch missing state on backfill to remember snapshots correctly * Fix gmsl to not mux in auth events into room state * Whoops * Linting
Diffstat (limited to 'roomserver')
-rw-r--r--roomserver/auth/auth.go4
-rw-r--r--roomserver/query/backfill.go81
-rw-r--r--roomserver/query/query.go108
-rw-r--r--roomserver/state/state.go5
-rw-r--r--roomserver/storage/sqlite3/events_table.go28
-rw-r--r--roomserver/storage/sqlite3/storage.go2
6 files changed, 181 insertions, 47 deletions
diff --git a/roomserver/auth/auth.go b/roomserver/auth/auth.go
index 615a94b3..fdcf9f06 100644
--- a/roomserver/auth/auth.go
+++ b/roomserver/auth/auth.go
@@ -27,7 +27,7 @@ func IsServerAllowed(
serverCurrentlyInRoom bool,
authEvents []gomatrixserverlib.Event,
) bool {
- historyVisibility := historyVisibilityForRoom(authEvents)
+ historyVisibility := HistoryVisibilityForRoom(authEvents)
// 1. If the history_visibility was set to world_readable, allow.
if historyVisibility == "world_readable" {
@@ -52,7 +52,7 @@ func IsServerAllowed(
return false
}
-func historyVisibilityForRoom(authEvents []gomatrixserverlib.Event) string {
+func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.Event) string {
// https://matrix.org/docs/spec/client_server/r0.6.0#id87
// By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared.
visibility := "shared"
diff --git a/roomserver/query/backfill.go b/roomserver/query/backfill.go
index 09a515e9..f518de3e 100644
--- a/roomserver/query/backfill.go
+++ b/roomserver/query/backfill.go
@@ -3,6 +3,7 @@ package query
import (
"context"
+ "github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -62,9 +63,9 @@ FederationHit:
logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event")
for _, srv := range b.servers { // hit any valid server
c := gomatrixserverlib.FederatedStateProvider{
- FedClient: b.fedClient,
- AuthEventsOnly: false,
- Server: srv,
+ FedClient: b.fedClient,
+ RememberAuthEvents: false,
+ Server: srv,
}
res, err := c.StateIDsBeforeEvent(ctx, targetEvent)
if err != nil {
@@ -114,7 +115,9 @@ func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrix
return nil
}
-func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
+func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
+ event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
+
// try to fetch the events from the database first
events, err := b.ProvideEvents(roomVer, eventIDs)
if err != nil {
@@ -133,9 +136,9 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr
}
c := gomatrixserverlib.FederatedStateProvider{
- FedClient: b.fedClient,
- AuthEventsOnly: false,
- Server: b.servers[0],
+ FedClient: b.fedClient,
+ RememberAuthEvents: false,
+ Server: b.servers[0],
}
result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs)
if err != nil {
@@ -160,18 +163,33 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID
return
}
+ stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID])
+ if err != nil {
+ logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
+ return
+ }
+
+ // possibly return all joined servers depending on history visiblity
+ memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries)
+ if err != nil {
+ logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
+ return
+ }
+ logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis))
+
// Retrieve all "m.room.member" state events of "join" membership, which
// contains the list of users in the room before the event, therefore all
// the servers in it at that moment.
- events, err := getMembershipsBeforeEventNID(ctx, b.db, NIDs[eventID], true)
+ memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true)
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
return
}
+ memberEvents = append(memberEvents, memberEventsFromVis...)
// Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[gomatrixserverlib.ServerName]bool)
- for _, event := range events {
+ for _, event := range memberEvents {
serverSet[event.Origin()] = true
}
for server := range serverSet {
@@ -186,7 +204,9 @@ func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID
// Backfill performs a backfill request to the given server.
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
-func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) {
+func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string,
+ fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) {
+
tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs)
return &tx, err
}
@@ -215,3 +235,44 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
}
return events, nil
}
+
+// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility.
+// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
+// pull all events and then filter by that table.
+func joinEventsFromHistoryVisibility(
+ ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) {
+
+ var eventNIDs []types.EventNID
+ for _, entry := range stateEntries {
+ // Filter the events to retrieve to only keep the membership events
+ if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID {
+ eventNIDs = append(eventNIDs, entry.EventNID)
+ break
+ }
+ }
+
+ // Get all of the events in this state
+ stateEvents, err := db.Events(ctx, eventNIDs)
+ if err != nil {
+ return nil, err
+ }
+ events := make([]gomatrixserverlib.Event, len(stateEvents))
+ for i := range stateEvents {
+ events[i] = stateEvents[i].Event
+ }
+ visibility := auth.HistoryVisibilityForRoom(events)
+ if visibility != "shared" {
+ logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility)
+ return nil, nil
+ }
+ // get joined members
+ roomNID, err := db.RoomNID(ctx, roomID)
+ if err != nil {
+ return nil, err
+ }
+ joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true)
+ if err != nil {
+ return nil, err
+ }
+ return db.Events(ctx, joinEventNIDs)
+}
diff --git a/roomserver/query/query.go b/roomserver/query/query.go
index a54fa58d..6778ac28 100644
--- a/roomserver/query/query.go
+++ b/roomserver/query/query.go
@@ -277,6 +277,7 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
response.JoinEvents = []gomatrixserverlib.ClientEvent{}
var events []types.Event
+ var stateEntries []types.StateEntry
if stillInRoom {
var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
@@ -286,7 +287,12 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, eventNIDs)
} else {
- events, err = getMembershipsBeforeEventNID(ctx, r.DB, membershipEventNID, request.JoinedOnly)
+ stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID)
+ if err != nil {
+ logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
+ return err
+ }
+ events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
}
if err != nil {
@@ -301,15 +307,8 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
return nil
}
-// getMembershipsBeforeEventNID takes the numeric ID of an event and fetches the state
-// of the event's room as it was when this event was fired, then filters the state events to
-// only keep the "m.room.member" events with a "join" membership. These events are returned.
-// Returns an error if there was an issue fetching the events.
-func getMembershipsBeforeEventNID(
- ctx context.Context, db storage.Database, eventNID types.EventNID, joinedOnly bool,
-) ([]types.Event, error) {
+func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db)
- events := []types.Event{}
// Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil {
@@ -323,10 +322,15 @@ func getMembershipsBeforeEventNID(
}
// Fetch the state as it was when this event was fired
- stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, prevState)
- if err != nil {
- return nil, err
- }
+ return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
+}
+
+// getMembershipsAtState filters the state events to
+// only keep the "m.room.member" events with a "join" membership. These events are returned.
+// Returns an error if there was an issue fetching the events.
+func getMembershipsAtState(
+ ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
+) ([]types.Event, error) {
var eventNIDs []types.EventNID
for _, entry := range stateEntries {
@@ -347,6 +351,7 @@ func getMembershipsBeforeEventNID(
}
// Filter the events to only keep the "join" membership events
+ var events []types.Event
for _, event := range stateEvents {
membership, err := event.Membership()
if err != nil {
@@ -563,20 +568,29 @@ func (r *RoomserverQueryAPI) backfillViaFederation(ctx context.Context, req *api
if !ok {
// this should be impossible as all events returned must have pass Step 5 of the PDU checks
// which requires a list of state IDs.
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to find state IDs for event which passed auth checks")
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
continue
}
var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
- return err
+ // attempt to fetch the missing events
+ r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs)
+ // try again
+ entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
+ if err != nil {
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
+ return err
+ }
}
var beforeStateSnapshotNID types.StateSnapshotNID
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
return err
}
+ util.GetLogger(ctx).Infof("Backfilled event %s (nid=%d) getting snapshot %v with entries %+v", ev.EventID(), ev.EventNID, beforeStateSnapshotNID, entries)
if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to set state before event")
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
}
}
@@ -608,6 +622,66 @@ func (r *RoomserverQueryAPI) isServerCurrentlyInRoom(ctx context.Context, server
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
}
+// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
+// best effort.
+func (r *RoomserverQueryAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
+ backfillRequester *backfillRequester, stateIDs []string) {
+
+ servers := backfillRequester.servers
+
+ // work out which are missing
+ nidMap, err := r.DB.EventNIDs(ctx, stateIDs)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
+ return
+ }
+ missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event
+ for _, id := range stateIDs {
+ if _, ok := nidMap[id]; !ok {
+ missingMap[id] = nil
+ }
+ }
+ util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers))
+
+ // fetch the events from federation. Loop the servers first so if we find one that works we stick with them
+ for _, srv := range servers {
+ for id, ev := range missingMap {
+ if ev != nil {
+ continue // already found
+ }
+ logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
+ res, err := r.FedClient.GetEvent(ctx, srv, id)
+ if err != nil {
+ logger.WithError(err).Warn("failed to get event from server")
+ continue
+ }
+ loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
+ result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
+ if err != nil {
+ logger.WithError(err).Warn("failed to load and verify event")
+ continue
+ }
+ logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result)
+ for _, res := range result {
+ if res.Error != nil {
+ logger.WithError(err).Warn("event failed PDU checks")
+ continue
+ }
+ missingMap[id] = res.Event
+ }
+ }
+ }
+
+ var newEvents []gomatrixserverlib.HeaderedEvent
+ for _, ev := range missingMap {
+ if ev != nil {
+ newEvents = append(newEvents, *ev)
+ }
+ }
+ util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
+ persistEvents(ctx, r.DB, newEvents)
+}
+
// TODO: Remove this when we have tests to assert correctness of this function
// nolint:gocyclo
func (r *RoomserverQueryAPI) scanEventTree(
@@ -857,7 +931,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
var stateAtEvent types.StateAtEvent
roomNID, stateAtEvent, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
if err != nil {
- logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to store backfilled event")
+ logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue
}
backfilledEventMap[ev.EventID()] = types.Event{
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index 389c9440..9b005ee6 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -86,7 +86,10 @@ func (v StateResolution) LoadStateAtEvent(
) ([]types.StateEntry, error) {
snapshotNID, err := v.db.SnapshotNIDFromEventID(ctx, eventID)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID failed for event %s : %s", eventID, err)
+ }
+ if snapshotNID == 0 {
+ return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
}
stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go
index d881fa91..a63596ae 100644
--- a/roomserver/storage/sqlite3/events_table.go
+++ b/roomserver/storage/sqlite3/events_table.go
@@ -48,11 +48,6 @@ const insertEventSQL = `
ON CONFLICT DO NOTHING;
`
-const insertEventResultSQL = `
- SELECT event_nid, state_snapshot_nid FROM roomserver_events
- WHERE rowid = last_insert_rowid();
-`
-
const selectEventSQL = "" +
"SELECT event_nid, state_snapshot_nid FROM roomserver_events WHERE event_id = $1"
@@ -102,7 +97,6 @@ const selectRoomNIDForEventNIDSQL = "" +
type eventStatements struct {
db *sql.DB
insertEventStmt *sql.Stmt
- insertEventResultStmt *sql.Stmt
selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt
bulkSelectStateAtEventByIDStmt *sql.Stmt
@@ -126,7 +120,6 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
return statementList{
{&s.insertEventStmt, insertEventSQL},
- {&s.insertEventResultStmt, insertEventResultSQL},
{&s.selectEventStmt, selectEventSQL},
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
@@ -152,19 +145,22 @@ func (s *eventStatements) insertEvent(
referenceSHA256 []byte,
authEventNIDs []types.EventNID,
depth int64,
-) (types.EventNID, types.StateSnapshotNID, error) {
- var eventNID int64
- var stateNID int64
- var err error
+) (types.EventNID, error) {
+ // attempt to insert: the last_row_id is the event NID
insertStmt := common.TxStmt(txn, s.insertEventStmt)
- resultStmt := common.TxStmt(txn, s.insertEventResultStmt)
- if _, err = insertStmt.ExecContext(
+ result, err := insertStmt.ExecContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
- ); err == nil {
- err = resultStmt.QueryRowContext(ctx).Scan(&eventNID, &stateNID)
+ )
+ if err != nil {
+ return 0, err
}
- return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
+ modified, err := result.RowsAffected()
+ if modified == 0 && err == nil {
+ return 0, sql.ErrNoRows
+ }
+ eventNID, err := result.LastInsertId()
+ return types.EventNID(eventNID), err
}
func (s *eventStatements) selectEvent(
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index 5df9c4e0..b6e846df 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -124,7 +124,7 @@ func (d *Database) StoreEvent(
}
}
- if eventNID, stateNID, err = d.statements.insertEvent(
+ if eventNID, err = d.statements.insertEvent(
ctx,
txn,
roomNID,