diff options
author | Kegsay <kegan@matrix.org> | 2020-05-26 16:45:28 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-26 16:45:28 +0100 |
commit | 803af87dc49f6db57019892215c6c6cf049b5c50 (patch) | |
tree | 0b2a130846fdae263ab2e0641fdbed711eea1163 /roomserver/storage/postgres | |
parent | 737c83e0ae496449327ef596811e984c2752e39b (diff) |
Convert events/event_json tables to share code (#1062)
* Convert event_json table
* Convert the events table
Diffstat (limited to 'roomserver/storage/postgres')
-rw-r--r-- | roomserver/storage/postgres/event_json_table.go | 25 | ||||
-rw-r--r-- | roomserver/storage/postgres/events_table.go | 43 | ||||
-rw-r--r-- | roomserver/storage/postgres/sql.go | 2 | ||||
-rw-r--r-- | roomserver/storage/postgres/storage.go | 106 |
4 files changed, 74 insertions, 102 deletions
diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 661c4472..a3262926 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -58,32 +59,28 @@ type eventJSONStatements struct { bulkSelectEventJSONStmt *sql.Stmt } -func (s *eventJSONStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventJSONSchema) +func NewPostgresEventJSONTable(db *sql.DB) (tables.EventJSON, error) { + s := &eventJSONStatements{} + _, err := db.Exec(eventJSONSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventJSONStmt, insertEventJSONSQL}, {&s.bulkSelectEventJSONStmt, bulkSelectEventJSONSQL}, }.prepare(db) } -func (s *eventJSONStatements) insertEventJSON( - ctx context.Context, eventNID types.EventNID, eventJSON []byte, +func (s *eventJSONStatements) InsertEventJSON( + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } -type eventJSONPair struct { - EventNID types.EventNID - EventJSON []byte -} - -func (s *eventJSONStatements) bulkSelectEventJSON( +func (s *eventJSONStatements) BulkSelectEventJSON( ctx context.Context, eventNIDs []types.EventNID, -) ([]eventJSONPair, error) { +) ([]tables.EventJSONPair, error) { rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err @@ -94,7 +91,7 @@ func (s *eventJSONStatements) bulkSelectEventJSON( // because of the unique constraint on event NIDs. // So we can allocate an array of the correct size now. // We might get fewer results than NIDs so we adjust the length of the slice before returning it. - results := make([]eventJSONPair, len(eventNIDs)) + results := make([]tables.EventJSONPair, len(eventNIDs)) i := 0 for ; rows.Next(); i++ { result := &results[i] diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index c28fa8e6..9c464946 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -22,6 +22,7 @@ import ( "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -136,13 +137,14 @@ type eventStatements struct { selectRoomNIDForEventNIDStmt *sql.Stmt } -func (s *eventStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(eventsSchema) +func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { + s := &eventStatements{} + _, err := db.Exec(eventsSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, statementList{ {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, @@ -160,8 +162,9 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { }.prepare(db) } -func (s *eventStatements) insertEvent( +func (s *eventStatements) InsertEvent( ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, @@ -179,8 +182,8 @@ func (s *eventStatements) insertEvent( return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } -func (s *eventStatements) selectEvent( - ctx context.Context, eventID string, +func (s *eventStatements) SelectEvent( + ctx context.Context, txn *sql.Tx, eventID string, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 @@ -190,7 +193,7 @@ func (s *eventStatements) selectEvent( // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError -func (s *eventStatements) bulkSelectStateEventByID( +func (s *eventStatements) BulkSelectStateEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateEntry, error) { rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) @@ -233,7 +236,7 @@ func (s *eventStatements) bulkSelectStateEventByID( // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. -func (s *eventStatements) bulkSelectStateAtEventByID( +func (s *eventStatements) BulkSelectStateAtEventByID( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) @@ -270,14 +273,14 @@ func (s *eventStatements) bulkSelectStateAtEventByID( return results, nil } -func (s *eventStatements) updateEventState( +func (s *eventStatements) UpdateEventState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) return err } -func (s *eventStatements) selectEventSentToOutput( +func (s *eventStatements) SelectEventSentToOutput( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (sentToOutput bool, err error) { stmt := internal.TxStmt(txn, s.selectEventSentToOutputStmt) @@ -285,13 +288,13 @@ func (s *eventStatements) selectEventSentToOutput( return } -func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { +func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error { stmt := internal.TxStmt(txn, s.updateEventSentToOutputStmt) _, err := stmt.ExecContext(ctx, int64(eventNID)) return err } -func (s *eventStatements) selectEventID( +func (s *eventStatements) SelectEventID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (eventID string, err error) { stmt := internal.TxStmt(txn, s.selectEventIDStmt) @@ -299,7 +302,7 @@ func (s *eventStatements) selectEventID( return } -func (s *eventStatements) bulkSelectStateAtEventAndReference( +func (s *eventStatements) BulkSelectStateAtEventAndReference( ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]types.StateAtEventAndReference, error) { stmt := internal.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt) @@ -341,8 +344,8 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference( return results, nil } -func (s *eventStatements) bulkSelectEventReference( - ctx context.Context, eventNIDs []types.EventNID, +func (s *eventStatements) BulkSelectEventReference( + ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, ) ([]gomatrixserverlib.EventReference, error) { rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { @@ -367,7 +370,7 @@ func (s *eventStatements) bulkSelectEventReference( } // bulkSelectEventID returns a map from numeric event ID to string event ID. -func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { +func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) { rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) if err != nil { return nil, err @@ -394,7 +397,7 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ // bulkSelectEventNIDs returns a map from string event ID to numeric event ID. // If an event ID is not in the database then it is omitted from the map. -func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { +func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) { rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs)) if err != nil { return nil, err @@ -412,7 +415,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str return results, rows.Err() } -func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) { +func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) { var result int64 stmt := s.selectMaxEventDepthStmt err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result) @@ -422,7 +425,7 @@ func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []t return result, nil } -func (s *eventStatements) selectRoomNIDForEventNID( +func (s *eventStatements) SelectRoomNIDForEventNID( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, ) (roomNID types.RoomNID, err error) { selectStmt := internal.TxStmt(txn, s.selectRoomNIDForEventNIDStmt) diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go index e41c5a39..964dabbb 100644 --- a/roomserver/storage/postgres/sql.go +++ b/roomserver/storage/postgres/sql.go @@ -39,8 +39,6 @@ func (s *statements) prepare(db *sql.DB) error { for _, prepare := range []func(db *sql.DB) error{ s.roomStatements.prepare, - s.eventStatements.prepare, - s.eventJSONStatements.prepare, s.stateSnapshotStatements.prepare, s.stateBlockStatements.prepare, s.previousEventStatements.prepare, diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 6fcceced..0022c617 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -36,8 +36,10 @@ import ( type Database struct { shared.Database statements statements + events tables.Events eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys + eventJSON tables.EventJSON db *sql.DB } @@ -59,9 +61,19 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.eventJSON, err = NewPostgresEventJSONTable(d.db) + if err != nil { + return nil, err + } + d.events, err = NewPostgresEventsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, + EventJSONTable: d.eventJSON, + EventsTable: d.events, } return &d, nil } @@ -120,8 +132,9 @@ func (d *Database) StoreEvent( } } - if eventNID, stateNID, err = d.statements.insertEvent( + if eventNID, stateNID, err = d.events.InsertEvent( ctx, + nil, roomNID, eventTypeNID, eventStateKeyNID, @@ -132,14 +145,14 @@ func (d *Database) StoreEvent( ); err != nil { if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) + eventNID, stateNID, err = d.events.SelectEvent(ctx, nil, event.EventID()) } if err != nil { return 0, types.StateAtEvent{}, err } } - if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { + if err = d.eventJSON.InsertEventJSON(ctx, nil, eventNID, event.JSON()); err != nil { return 0, types.StateAtEvent{}, err } @@ -230,25 +243,11 @@ func (d *Database) assignStateKeyNID( return eventStateKeyNID, err } -// StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return d.statements.bulkSelectStateEventByID(ctx, eventIDs) -} - -// EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return d.statements.bulkSelectEventNID(ctx, eventIDs) -} - // Events implements input.EventDatabase func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) + eventJSONs, err := d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs) if err != nil { return nil, err } @@ -258,7 +257,7 @@ func (d *Database) Events( var roomVersion gomatrixserverlib.RoomVersion result := &results[i] result.EventNID = eventJSON.EventNID - roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) + roomNID, err = d.events.SelectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) if err != nil { return nil, err } @@ -297,20 +296,6 @@ func (d *Database) AddState( return d.statements.insertState(ctx, roomNID, stateBlockNIDs) } -// SetState implements input.EventDatabase -func (d *Database) SetState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, -) error { - return d.statements.updateEventState(ctx, eventNID, stateNID) -} - -// StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateAtEvent, error) { - return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) -} - // StateBlockNIDs implements state.RoomStateDatabase func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, @@ -325,21 +310,6 @@ func (d *Database) StateEntries( return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) } -// SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID( - ctx context.Context, eventID string, -) (types.StateSnapshotNID, error) { - _, stateNID, err := d.statements.selectEvent(ctx, eventID) - return stateNID, err -} - -// EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs( - ctx context.Context, eventNIDs []types.EventNID, -) (map[types.EventNID]string, error) { - return d.statements.bulkSelectEventID(ctx, eventNIDs) -} - // GetLatestEventsForUpdate implements input.EventDatabase func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, @@ -354,14 +324,14 @@ func (d *Database) GetLatestEventsForUpdate( txn.Rollback() // nolint: errcheck return nil, err } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + stateAndRefs, err := d.events.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) if err != nil { txn.Rollback() // nolint: errcheck return nil, err } var lastEventIDSent string if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) + lastEventIDSent, err = d.events.SelectEventID(ctx, txn, lastEventNIDSent) if err != nil { txn.Rollback() // nolint: errcheck return nil, err @@ -450,12 +420,12 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( // HasEventBeenSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) + return u.d.events.SelectEventSentToOutput(u.ctx, u.txn, eventNID) } // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) + return u.d.events.UpdateEventSentToOutput(u.ctx, u.txn, eventNID) } func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { @@ -491,20 +461,24 @@ func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (ro // LatestEventIDs implements query.RoomserverQueryAPIDatabase func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, -) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { - eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return nil, 0, 0, err - } - references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - return references, currentStateSnapshotNID, depth, nil +) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { + err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, roomNID) + if err != nil { + return err + } + references, err = d.events.BulkSelectEventReference(ctx, txn, eventNIDs) + if err != nil { + return err + } + depth, err = d.events.SelectMaxEventDepth(ctx, txn, eventNIDs) + if err != nil { + return err + } + return nil + }) + return } // GetInvitesForUser implements query.RoomserverQueryAPIDatabase |