From 9ed68a3125d9024f52bf89810abf3b203f4b25b7 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 14 May 2020 09:53:55 +0100 Subject: Factor out account data and events table (#1031) * Factor out account data * Factor out events table and EDU cache * linting * fix npe --- syncapi/storage/postgres/account_data_table.go | 24 +-- .../storage/postgres/output_room_events_table.go | 36 +++-- syncapi/storage/postgres/syncserver.go | 180 +++------------------ 3 files changed, 58 insertions(+), 182 deletions(-) (limited to 'syncapi/storage/postgres') diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index d1e3b527..58fb2198 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -70,32 +71,33 @@ type accountDataStatements struct { selectMaxAccountDataIDStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(accountDataSchema) +func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountData, error) { + s := &accountDataStatements{} + _, err := db.Exec(accountDataSchema) if err != nil { - return + return nil, err } if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return + return nil, err } if s.selectAccountDataInRangeStmt, err = db.Prepare(selectAccountDataInRangeSQL); err != nil { - return + return nil, err } if s.selectMaxAccountDataIDStmt, err = db.Prepare(selectMaxAccountDataIDSQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *accountDataStatements) insertAccountData( - ctx context.Context, +func (s *accountDataStatements) InsertAccountData( + ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { err = s.insertAccountDataStmt.QueryRowContext(ctx, userID, roomID, dataType).Scan(&pos) return } -func (s *accountDataStatements) selectAccountDataInRange( +func (s *accountDataStatements) SelectAccountDataInRange( ctx context.Context, userID string, oldPos, newPos types.StreamPosition, @@ -137,7 +139,7 @@ func (s *accountDataStatements) selectAccountDataInRange( return data, rows.Err() } -func (s *accountDataStatements) selectMaxAccountDataID( +func (s *accountDataStatements) SelectMaxAccountDataID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 0b53dfa9..5870bfd5 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -22,6 +22,7 @@ import ( "sort" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/lib/pq" @@ -120,39 +121,40 @@ type outputRoomEventsStatements struct { selectStateInRangeStmt *sql.Stmt } -func (s *outputRoomEventsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(outputRoomEventsSchema) +func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { + s := &outputRoomEventsStatements{} + _, err := db.Exec(outputRoomEventsSchema) if err != nil { - return + return nil, err } if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { - return + return nil, err } if s.selectEventsStmt, err = db.Prepare(selectEventsSQL); err != nil { - return + return nil, err } if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil { - return + return nil, err } if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil { - return + return nil, err } if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil { - return + return nil, err } if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil { - return + return nil, err } if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil { - return + return nil, err } - return + return s, nil } // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // two positions, only the most recent state is returned. -func (s *outputRoomEventsStatements) selectStateInRange( +func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, stateFilter *gomatrixserverlib.StateFilter, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { @@ -233,7 +235,7 @@ func (s *outputRoomEventsStatements) selectStateInRange( // MaxID returns the ID of the last inserted event in this table. 'txn' is optional. If it is not supplied, // then this function should only ever be used at startup, as it will race with inserting events if it is // done afterwards. If there are no inserted events, 0 is returned. -func (s *outputRoomEventsStatements) selectMaxEventID( +func (s *outputRoomEventsStatements) SelectMaxEventID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 @@ -247,7 +249,7 @@ func (s *outputRoomEventsStatements) selectMaxEventID( // InsertEvent into the output_room_events table. addState and removeState are an optional list of state event IDs. Returns the position // of the inserted event. -func (s *outputRoomEventsStatements) insertEvent( +func (s *outputRoomEventsStatements) InsertEvent( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool, @@ -294,7 +296,7 @@ func (s *outputRoomEventsStatements) insertEvent( // selectRecentEvents returns the most recent events in the given room, up to a maximum of 'limit'. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude // from sync. -func (s *outputRoomEventsStatements) selectRecentEvents( +func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int, chronologicalOrder bool, onlySyncEvents bool, @@ -327,7 +329,7 @@ func (s *outputRoomEventsStatements) selectRecentEvents( // selectEarlyEvents returns the earliest events in the given room, starting // from a given position, up to a maximum of 'limit'. -func (s *outputRoomEventsStatements) selectEarlyEvents( +func (s *outputRoomEventsStatements) SelectEarlyEvents( ctx context.Context, txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int, ) ([]types.StreamEvent, error) { @@ -352,7 +354,7 @@ func (s *outputRoomEventsStatements) selectEarlyEvents( // selectEvents returns the events for the given event IDs. If an event is // missing from the database, it will be omitted. -func (s *outputRoomEventsStatements) selectEvents( +func (s *outputRoomEventsStatements) SelectEvents( ctx context.Context, txn *sql.Tx, eventIDs []string, ) ([]types.StreamEvent, error) { stmt := common.TxStmt(txn, s.selectEventsStmt) diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 9883c362..4fa08ce5 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -20,9 +20,6 @@ import ( "database/sql" "encoding/json" "fmt" - "time" - - "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -50,15 +47,12 @@ type stateDelta struct { // SyncServerDatasource represents a sync server datasource which manages // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { + shared.Database db *sql.DB common.PartitionOffsetStatements - accountData accountDataStatements - events outputRoomEventsStatements roomstate currentRoomStateStatements - eduCache *cache.EDUCache topology outputRoomEventsTopologyStatements backwardExtremities tables.BackwardsExtremities - shared *shared.Database } // NewSyncServerDatasource creates a new sync server database @@ -71,10 +65,12 @@ func NewSyncServerDatasource(dbDataSourceName string, dbProperties common.DbProp if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { return nil, err } - if err = d.accountData.prepare(d.db); err != nil { + accountData, err := NewPostgresAccountDataTable(d.db) + if err != nil { return nil, err } - if err = d.events.prepare(d.db); err != nil { + events, err := NewPostgresEventsTable(d.db) + if err != nil { return nil, err } if err = d.roomstate.prepare(d.db); err != nil { @@ -91,10 +87,12 @@ func NewSyncServerDatasource(dbDataSourceName string, dbProperties common.DbProp if err != nil { return nil, err } - d.eduCache = cache.New() - d.shared = &shared.Database{ - DB: d.db, - Invites: invites, + d.Database = shared.Database{ + DB: d.db, + Invites: invites, + AccountData: accountData, + OutputEvents: events, + EDUCache: cache.New(), } return &d, nil } @@ -103,17 +101,6 @@ func (d *SyncServerDatasource) AllJoinedUsersInRooms(ctx context.Context) (map[s return d.roomstate.selectJoinedUsers(ctx) } -func (d *SyncServerDatasource) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) { - streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) - if err != nil { - return nil, err - } - - // We don't include a device here as we only include transaction IDs in - // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil -} - // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. @@ -124,7 +111,7 @@ func (d *SyncServerDatasource) handleBackwardExtremities(ctx context.Context, tx // Check if we have all of the event's previous events. If an event is // missing, add it to the room's backward extremities. - prevEvents, err := d.events.selectEvents(ctx, txn, ev.PrevEventIDs()) + prevEvents, err := d.Database.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs()) if err != nil { return err } @@ -157,7 +144,7 @@ func (d *SyncServerDatasource) WriteEvent( ) (pduPosition types.StreamPosition, returnErr error) { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error - pos, err := d.events.insertEvent( + pos, err := d.Database.OutputEvents.InsertEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ) if err != nil { @@ -265,36 +252,10 @@ func (d *SyncServerDatasource) GetEventsInTopologicalRange( } // Retrieve the events' contents using their IDs. - events, err = d.events.selectEvents(ctx, nil, eIDs) + events, err = d.Database.OutputEvents.SelectEvents(ctx, nil, eIDs) return } -// GetEventsInStreamingRange retrieves all of the events on a given ordering using the -// given extremities and limit. -func (d *SyncServerDatasource) GetEventsInStreamingRange( - ctx context.Context, - from, to *types.StreamingToken, - roomID string, limit int, - backwardOrdering bool, -) (events []types.StreamEvent, err error) { - if backwardOrdering { - // When using backward ordering, we want the most recent events first. - if events, err = d.events.selectRecentEvents( - ctx, nil, roomID, to.PDUPosition(), from.PDUPosition(), limit, false, false, - ); err != nil { - return - } - } else { - // When using forward ordering, we want the least recent events first. - if events, err = d.events.selectEarlyEvents( - ctx, nil, roomID, from.PDUPosition(), to.PDUPosition(), limit, - ); err != nil { - return - } - } - return events, err -} - func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.StreamingToken, error) { return d.syncPositionTx(ctx, nil) } @@ -319,7 +280,7 @@ func (d *SyncServerDatasource) EventsAtTopologicalPosition( return nil, err } - return d.events.selectEvents(ctx, nil, eIDs) + return d.Database.OutputEvents.SelectEvents(ctx, nil, eIDs) } func (d *SyncServerDatasource) EventPositionInTopology( @@ -328,57 +289,29 @@ func (d *SyncServerDatasource) EventPositionInTopology( return d.topology.selectPositionInTopology(ctx, eventID) } -func (d *SyncServerDatasource) SyncStreamPosition(ctx context.Context) (types.StreamPosition, error) { - return d.syncStreamPositionTx(ctx, nil) -} - -func (d *SyncServerDatasource) syncStreamPositionTx( - ctx context.Context, txn *sql.Tx, -) (types.StreamPosition, error) { - maxID, err := d.events.selectMaxEventID(ctx, txn) - if err != nil { - return 0, err - } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) - if err != nil { - return 0, err - } - if maxAccountDataID > maxID { - maxID = maxAccountDataID - } - maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) - if err != nil { - return 0, err - } - if maxInviteID > maxID { - maxID = maxInviteID - } - return types.StreamPosition(maxID), nil -} - func (d *SyncServerDatasource) syncPositionTx( ctx context.Context, txn *sql.Tx, ) (sp types.StreamingToken, err error) { - maxEventID, err := d.events.selectMaxEventID(ctx, txn) + maxEventID, err := d.Database.OutputEvents.SelectMaxEventID(ctx, txn) if err != nil { return sp, err } - maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) + maxAccountDataID, err := d.Database.AccountData.SelectMaxAccountDataID(ctx, txn) if err != nil { return sp, err } if maxAccountDataID > maxEventID { maxEventID = maxAccountDataID } - maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) + maxInviteID, err := d.Database.Invites.SelectMaxInviteID(ctx, txn) if err != nil { return sp, err } if maxInviteID > maxEventID { maxEventID = maxInviteID } - sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.eduCache.GetLatestSyncPosition())) + sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.Database.EDUCache.GetLatestSyncPosition())) return } @@ -451,7 +384,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse( var ok bool var err error for _, roomID := range joinedRoomIDs { - if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter( + if typingUsers, updated := d.Database.EDUCache.GetTypingUsersIfUpdatedAfter( roomID, int64(since.EDUPosition()), ); updated { ev := gomatrixserverlib.ClientEvent{ @@ -580,7 +513,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync( // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 var recentStreamEvents []types.StreamEvent - recentStreamEvents, err = d.events.selectRecentEvents( + recentStreamEvents, err = d.Database.OutputEvents.SelectRecentEvents( ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition(), numRecentEventsPerRoom, true, true, ) @@ -653,54 +586,13 @@ var txReadOnlySnapshot = sql.TxOptions{ ReadOnly: true, } -func (d *SyncServerDatasource) GetAccountDataInRange( - ctx context.Context, userID string, oldPos, newPos types.StreamPosition, - accountDataFilterPart *gomatrixserverlib.EventFilter, -) (map[string][]string, error) { - return d.accountData.selectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) -} - -func (d *SyncServerDatasource) UpsertAccountData( - ctx context.Context, userID, roomID, dataType string, -) (types.StreamPosition, error) { - return d.accountData.insertAccountData(ctx, userID, roomID, dataType) -} - -func (d *SyncServerDatasource) AddInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, -) (sp types.StreamPosition, err error) { - return d.shared.AddInviteEvent(ctx, inviteEvent) -} - -func (d *SyncServerDatasource) RetireInviteEvent( - ctx context.Context, inviteEventID string, -) error { - return d.shared.RetireInviteEvent(ctx, inviteEventID) -} - -func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { - d.eduCache.SetTimeoutCallback(fn) -} - -func (d *SyncServerDatasource) AddTypingUser( - userID, roomID string, expireTime *time.Time, -) types.StreamPosition { - return types.StreamPosition(d.eduCache.AddTypingUser(userID, roomID, expireTime)) -} - -func (d *SyncServerDatasource) RemoveTypingUser( - userID, roomID string, -) types.StreamPosition { - return types.StreamPosition(d.eduCache.RemoveUser(userID, roomID)) -} - func (d *SyncServerDatasource) addInvitesToResponse( ctx context.Context, txn *sql.Tx, userID string, fromPos, toPos types.StreamPosition, res *types.Response, ) error { - invites, err := d.shared.Invites.SelectInviteEventsInRange( + invites, err := d.Database.Invites.SelectInviteEventsInRange( ctx, txn, userID, fromPos, toPos, ) if err != nil { @@ -750,7 +642,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse( // This is all "okay" assuming history_visibility == "shared" which it is by default. endPos = delta.membershipPos } - recentStreamEvents, err := d.events.selectRecentEvents( + recentStreamEvents, err := d.Database.OutputEvents.SelectRecentEvents( ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos), numRecentEventsPerRoom, true, true, ) @@ -841,7 +733,7 @@ func (d *SyncServerDatasource) fetchMissingStateEvents( ) ([]types.StreamEvent, error) { // Fetch from the events table first so we pick up the stream ID for the // event. - events, err := d.events.selectEvents(ctx, txn, eventIDs) + events, err := d.Database.OutputEvents.SelectEvents(ctx, txn, eventIDs) if err != nil { return nil, err } @@ -895,7 +787,7 @@ func (d *SyncServerDatasource) getStateDeltas( var deltas []stateDelta // get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilter) + stateNeeded, eventMap, err := d.Database.OutputEvents.SelectStateInRange(ctx, txn, fromPos, toPos, stateFilter) if err != nil { return nil, nil, err } @@ -981,7 +873,7 @@ func (d *SyncServerDatasource) getStateDeltasForFullStateSync( } // Get all the state events ever between these two positions - stateNeeded, eventMap, err := d.events.selectStateInRange(ctx, txn, fromPos, toPos, stateFilter) + stateNeeded, eventMap, err := d.Database.OutputEvents.SelectStateInRange(ctx, txn, fromPos, toPos, stateFilter) if err != nil { return nil, nil, err } @@ -1025,26 +917,6 @@ func (d *SyncServerDatasource) currentStateStreamEventsForRoom( return s, nil } -func (d *SyncServerDatasource) StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent { - out := make([]gomatrixserverlib.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[i].HeaderedEvent - if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { - err := out[i].SetUnsignedField( - "transaction_id", in[i].TransactionID.TransactionID, - ) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - } - } - } - } - return out -} - // There may be some overlap where events in stateEvents are already in recentEvents, so filter // them out so we don't include them twice in the /sync response. They should be in recentEvents // only, so clients get to the correct state once they have rolled forward. -- cgit v1.2.3