diff options
Diffstat (limited to 'roomserver/storage/postgres/storage.go')
-rw-r--r-- | roomserver/storage/postgres/storage.go | 106 |
1 files changed, 40 insertions, 66 deletions
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 6fcceced..0022c617 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -36,8 +36,10 @@ import ( type Database struct { shared.Database statements statements + events tables.Events eventTypes tables.EventTypes eventStateKeys tables.EventStateKeys + eventJSON tables.EventJSON db *sql.DB } @@ -59,9 +61,19 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, if err != nil { return nil, err } + d.eventJSON, err = NewPostgresEventJSONTable(d.db) + if err != nil { + return nil, err + } + d.events, err = NewPostgresEventsTable(d.db) + if err != nil { + return nil, err + } d.Database = shared.Database{ EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, + EventJSONTable: d.eventJSON, + EventsTable: d.events, } return &d, nil } @@ -120,8 +132,9 @@ func (d *Database) StoreEvent( } } - if eventNID, stateNID, err = d.statements.insertEvent( + if eventNID, stateNID, err = d.events.InsertEvent( ctx, + nil, roomNID, eventTypeNID, eventStateKeyNID, @@ -132,14 +145,14 @@ func (d *Database) StoreEvent( ); err != nil { if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) + eventNID, stateNID, err = d.events.SelectEvent(ctx, nil, event.EventID()) } if err != nil { return 0, types.StateAtEvent{}, err } } - if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { + if err = d.eventJSON.InsertEventJSON(ctx, nil, eventNID, event.JSON()); err != nil { return 0, types.StateAtEvent{}, err } @@ -230,25 +243,11 @@ func (d *Database) assignStateKeyNID( return eventStateKeyNID, err } -// StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return d.statements.bulkSelectStateEventByID(ctx, eventIDs) -} - -// EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return d.statements.bulkSelectEventNID(ctx, eventIDs) -} - // Events implements input.EventDatabase func (d *Database) Events( ctx context.Context, eventNIDs []types.EventNID, ) ([]types.Event, error) { - eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) + eventJSONs, err := d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs) if err != nil { return nil, err } @@ -258,7 +257,7 @@ func (d *Database) Events( var roomVersion gomatrixserverlib.RoomVersion result := &results[i] result.EventNID = eventJSON.EventNID - roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) + roomNID, err = d.events.SelectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID) if err != nil { return nil, err } @@ -297,20 +296,6 @@ func (d *Database) AddState( return d.statements.insertState(ctx, roomNID, stateBlockNIDs) } -// SetState implements input.EventDatabase -func (d *Database) SetState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, -) error { - return d.statements.updateEventState(ctx, eventNID, stateNID) -} - -// StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateAtEvent, error) { - return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) -} - // StateBlockNIDs implements state.RoomStateDatabase func (d *Database) StateBlockNIDs( ctx context.Context, stateNIDs []types.StateSnapshotNID, @@ -325,21 +310,6 @@ func (d *Database) StateEntries( return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) } -// SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID( - ctx context.Context, eventID string, -) (types.StateSnapshotNID, error) { - _, stateNID, err := d.statements.selectEvent(ctx, eventID) - return stateNID, err -} - -// EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs( - ctx context.Context, eventNIDs []types.EventNID, -) (map[types.EventNID]string, error) { - return d.statements.bulkSelectEventID(ctx, eventNIDs) -} - // GetLatestEventsForUpdate implements input.EventDatabase func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, @@ -354,14 +324,14 @@ func (d *Database) GetLatestEventsForUpdate( txn.Rollback() // nolint: errcheck return nil, err } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + stateAndRefs, err := d.events.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) if err != nil { txn.Rollback() // nolint: errcheck return nil, err } var lastEventIDSent string if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) + lastEventIDSent, err = d.events.SelectEventID(ctx, txn, lastEventNIDSent) if err != nil { txn.Rollback() // nolint: errcheck return nil, err @@ -450,12 +420,12 @@ func (u *roomRecentEventsUpdater) SetLatestEvents( // HasEventBeenSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) + return u.d.events.SelectEventSentToOutput(u.ctx, u.txn, eventNID) } // MarkEventAsSent implements types.RoomRecentEventsUpdater func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) + return u.d.events.UpdateEventSentToOutput(u.ctx, u.txn, eventNID) } func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) { @@ -491,20 +461,24 @@ func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (ro // LatestEventIDs implements query.RoomserverQueryAPIDatabase func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, -) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { - eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return nil, 0, 0, err - } - references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - return references, currentStateSnapshotNID, depth, nil +) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { + err = internal.WithTransaction(d.db, func(txn *sql.Tx) error { + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, roomNID) + if err != nil { + return err + } + references, err = d.events.BulkSelectEventReference(ctx, txn, eventNIDs) + if err != nil { + return err + } + depth, err = d.events.SelectMaxEventDepth(ctx, txn, eventNIDs) + if err != nil { + return err + } + return nil + }) + return } // GetInvitesForUser implements query.RoomserverQueryAPIDatabase |