aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/postgres/storage.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/postgres/storage.go')
-rw-r--r--roomserver/storage/postgres/storage.go106
1 files changed, 40 insertions, 66 deletions
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 6fcceced..0022c617 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -36,8 +36,10 @@ import (
type Database struct {
shared.Database
statements statements
+ events tables.Events
eventTypes tables.EventTypes
eventStateKeys tables.EventStateKeys
+ eventJSON tables.EventJSON
db *sql.DB
}
@@ -59,9 +61,19 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
if err != nil {
return nil, err
}
+ d.eventJSON, err = NewPostgresEventJSONTable(d.db)
+ if err != nil {
+ return nil, err
+ }
+ d.events, err = NewPostgresEventsTable(d.db)
+ if err != nil {
+ return nil, err
+ }
d.Database = shared.Database{
EventTypesTable: d.eventTypes,
EventStateKeysTable: d.eventStateKeys,
+ EventJSONTable: d.eventJSON,
+ EventsTable: d.events,
}
return &d, nil
}
@@ -120,8 +132,9 @@ func (d *Database) StoreEvent(
}
}
- if eventNID, stateNID, err = d.statements.insertEvent(
+ if eventNID, stateNID, err = d.events.InsertEvent(
ctx,
+ nil,
roomNID,
eventTypeNID,
eventStateKeyNID,
@@ -132,14 +145,14 @@ func (d *Database) StoreEvent(
); err != nil {
if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID
- eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID())
+ eventNID, stateNID, err = d.events.SelectEvent(ctx, nil, event.EventID())
}
if err != nil {
return 0, types.StateAtEvent{}, err
}
}
- if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil {
+ if err = d.eventJSON.InsertEventJSON(ctx, nil, eventNID, event.JSON()); err != nil {
return 0, types.StateAtEvent{}, err
}
@@ -230,25 +243,11 @@ func (d *Database) assignStateKeyNID(
return eventStateKeyNID, err
}
-// StateEntriesForEventIDs implements input.EventDatabase
-func (d *Database) StateEntriesForEventIDs(
- ctx context.Context, eventIDs []string,
-) ([]types.StateEntry, error) {
- return d.statements.bulkSelectStateEventByID(ctx, eventIDs)
-}
-
-// EventNIDs implements query.RoomserverQueryAPIDatabase
-func (d *Database) EventNIDs(
- ctx context.Context, eventIDs []string,
-) (map[string]types.EventNID, error) {
- return d.statements.bulkSelectEventNID(ctx, eventIDs)
-}
-
// Events implements input.EventDatabase
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
- eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs)
+ eventJSONs, err := d.eventJSON.BulkSelectEventJSON(ctx, eventNIDs)
if err != nil {
return nil, err
}
@@ -258,7 +257,7 @@ func (d *Database) Events(
var roomVersion gomatrixserverlib.RoomVersion
result := &results[i]
result.EventNID = eventJSON.EventNID
- roomNID, err = d.statements.selectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID)
+ roomNID, err = d.events.SelectRoomNIDForEventNID(ctx, nil, eventJSON.EventNID)
if err != nil {
return nil, err
}
@@ -297,20 +296,6 @@ func (d *Database) AddState(
return d.statements.insertState(ctx, roomNID, stateBlockNIDs)
}
-// SetState implements input.EventDatabase
-func (d *Database) SetState(
- ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
-) error {
- return d.statements.updateEventState(ctx, eventNID, stateNID)
-}
-
-// StateAtEventIDs implements input.EventDatabase
-func (d *Database) StateAtEventIDs(
- ctx context.Context, eventIDs []string,
-) ([]types.StateAtEvent, error) {
- return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs)
-}
-
// StateBlockNIDs implements state.RoomStateDatabase
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
@@ -325,21 +310,6 @@ func (d *Database) StateEntries(
return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs)
}
-// SnapshotNIDFromEventID implements state.RoomStateDatabase
-func (d *Database) SnapshotNIDFromEventID(
- ctx context.Context, eventID string,
-) (types.StateSnapshotNID, error) {
- _, stateNID, err := d.statements.selectEvent(ctx, eventID)
- return stateNID, err
-}
-
-// EventIDs implements input.RoomEventDatabase
-func (d *Database) EventIDs(
- ctx context.Context, eventNIDs []types.EventNID,
-) (map[types.EventNID]string, error) {
- return d.statements.bulkSelectEventID(ctx, eventNIDs)
-}
-
// GetLatestEventsForUpdate implements input.EventDatabase
func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID,
@@ -354,14 +324,14 @@ func (d *Database) GetLatestEventsForUpdate(
txn.Rollback() // nolint: errcheck
return nil, err
}
- stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
+ stateAndRefs, err := d.events.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
txn.Rollback() // nolint: errcheck
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
- lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
+ lastEventIDSent, err = d.events.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
txn.Rollback() // nolint: errcheck
return nil, err
@@ -450,12 +420,12 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
- return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID)
+ return u.d.events.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
}
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
- return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID)
+ return u.d.events.UpdateEventSentToOutput(u.ctx, u.txn, eventNID)
}
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (types.MembershipUpdater, error) {
@@ -491,20 +461,24 @@ func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (ro
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
func (d *Database) LatestEventIDs(
ctx context.Context, roomNID types.RoomNID,
-) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
- eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID)
- if err != nil {
- return nil, 0, 0, err
- }
- references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs)
- if err != nil {
- return nil, 0, 0, err
- }
- depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs)
- if err != nil {
- return nil, 0, 0, err
- }
- return references, currentStateSnapshotNID, depth, nil
+) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
+ err = internal.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var eventNIDs []types.EventNID
+ eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, roomNID)
+ if err != nil {
+ return err
+ }
+ references, err = d.events.BulkSelectEventReference(ctx, txn, eventNIDs)
+ if err != nil {
+ return err
+ }
+ depth, err = d.events.SelectMaxEventDepth(ctx, txn, eventNIDs)
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ return
}
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase