aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/postgres
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/postgres')
-rw-r--r--roomserver/storage/postgres/previous_events_table.go14
-rw-r--r--roomserver/storage/postgres/room_aliases_table.go20
-rw-r--r--roomserver/storage/postgres/sql.go4
-rw-r--r--roomserver/storage/postgres/state_block_table.go28
-rw-r--r--roomserver/storage/postgres/state_snapshot_table.go18
-rw-r--r--roomserver/storage/postgres/storage.go99
6 files changed, 70 insertions, 113 deletions
diff --git a/roomserver/storage/postgres/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go
index e3ad5dc8..b3d32c95 100644
--- a/roomserver/storage/postgres/previous_events_table.go
+++ b/roomserver/storage/postgres/previous_events_table.go
@@ -20,6 +20,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
@@ -63,19 +64,20 @@ type previousEventStatements struct {
selectPreviousEventExistsStmt *sql.Stmt
}
-func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
- _, err = db.Exec(previousEventSchema)
+func NewPostgresPreviousEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
+ s := &previousEventStatements{}
+ _, err := db.Exec(previousEventSchema)
if err != nil {
- return
+ return nil, err
}
- return statementList{
+ return s, statementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.prepare(db)
}
-func (s *previousEventStatements) insertPreviousEvent(
+func (s *previousEventStatements) InsertPreviousEvent(
ctx context.Context,
txn *sql.Tx,
previousEventID string,
@@ -91,7 +93,7 @@ func (s *previousEventStatements) insertPreviousEvent(
// Check if the event reference exists
// Returns sql.ErrNoRows if the event reference doesn't exist.
-func (s *previousEventStatements) selectPreviousEventExists(
+func (s *previousEventStatements) SelectPreviousEventExists(
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
) error {
var ok int64
diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go
index c77edd0e..f869cf4f 100644
--- a/roomserver/storage/postgres/room_aliases_table.go
+++ b/roomserver/storage/postgres/room_aliases_table.go
@@ -20,6 +20,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
)
const roomAliasesSchema = `
@@ -59,12 +60,13 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
-func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
- _, err = db.Exec(roomAliasesSchema)
+func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
+ s := &roomAliasesStatements{}
+ _, err := db.Exec(roomAliasesSchema)
if err != nil {
- return
+ return nil, err
}
- return statementList{
+ return s, statementList{
{&s.insertRoomAliasStmt, insertRoomAliasSQL},
{&s.selectRoomIDFromAliasStmt, selectRoomIDFromAliasSQL},
{&s.selectAliasesFromRoomIDStmt, selectAliasesFromRoomIDSQL},
@@ -73,14 +75,14 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
-func (s *roomAliasesStatements) insertRoomAlias(
+func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, alias string, roomID string, creatorUserID string,
) (err error) {
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
return
}
-func (s *roomAliasesStatements) selectRoomIDFromAlias(
+func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string,
) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
@@ -90,7 +92,7 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias(
return
}
-func (s *roomAliasesStatements) selectAliasesFromRoomID(
+func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string,
) ([]string, error) {
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
@@ -111,7 +113,7 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID(
return aliases, rows.Err()
}
-func (s *roomAliasesStatements) selectCreatorIDFromAlias(
+func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string,
) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
@@ -121,7 +123,7 @@ func (s *roomAliasesStatements) selectCreatorIDFromAlias(
return
}
-func (s *roomAliasesStatements) deleteRoomAlias(
+func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, alias string,
) (err error) {
_, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
diff --git a/roomserver/storage/postgres/sql.go b/roomserver/storage/postgres/sql.go
index 914f269c..eb626dd8 100644
--- a/roomserver/storage/postgres/sql.go
+++ b/roomserver/storage/postgres/sql.go
@@ -38,10 +38,6 @@ func (s *statements) prepare(db *sql.DB) error {
var err error
for _, prepare := range []func(db *sql.DB) error{
- s.stateSnapshotStatements.prepare,
- s.stateBlockStatements.prepare,
- s.previousEventStatements.prepare,
- s.roomAliasesStatements.prepare,
s.inviteStatements.prepare,
s.membershipStatements.prepare,
} {
diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go
index 38334fa9..d1aaaa00 100644
--- a/roomserver/storage/postgres/state_block_table.go
+++ b/roomserver/storage/postgres/state_block_table.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/lib/pq"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util"
)
@@ -87,13 +88,14 @@ type stateBlockStatements struct {
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
}
-func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
- _, err = db.Exec(stateDataSchema)
+func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
+ s := &stateBlockStatements{}
+ _, err := db.Exec(stateDataSchema)
if err != nil {
- return
+ return nil, err
}
- return statementList{
+ return s, statementList{
{&s.insertStateDataStmt, insertStateDataSQL},
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
@@ -101,11 +103,15 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
}.prepare(db)
}
-func (s *stateBlockStatements) bulkInsertStateData(
+func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context,
- stateBlockNID types.StateBlockNID,
+ txn *sql.Tx,
entries []types.StateEntry,
-) error {
+) (types.StateBlockNID, error) {
+ stateBlockNID, err := s.selectNextStateBlockNID(ctx)
+ if err != nil {
+ return 0, err
+ }
for _, entry := range entries {
_, err := s.insertStateDataStmt.ExecContext(
ctx,
@@ -115,10 +121,10 @@ func (s *stateBlockStatements) bulkInsertStateData(
int64(entry.EventNID),
)
if err != nil {
- return err
+ return 0, err
}
}
- return nil
+ return stateBlockNID, nil
}
func (s *stateBlockStatements) selectNextStateBlockNID(
@@ -129,7 +135,7 @@ func (s *stateBlockStatements) selectNextStateBlockNID(
return types.StateBlockNID(stateBlockNID), err
}
-func (s *stateBlockStatements) bulkSelectStateBlockEntries(
+func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
nids := make([]int64, len(stateBlockNIDs))
@@ -180,7 +186,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
return results, err
}
-func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
+func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go
index a1f26e22..8971292f 100644
--- a/roomserver/storage/postgres/state_snapshot_table.go
+++ b/roomserver/storage/postgres/state_snapshot_table.go
@@ -21,6 +21,7 @@ import (
"fmt"
"github.com/lib/pq"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
@@ -64,30 +65,31 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
-func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
- _, err = db.Exec(stateSnapshotSchema)
+func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
+ s := &stateSnapshotStatements{}
+ _, err := db.Exec(stateSnapshotSchema)
if err != nil {
- return
+ return nil, err
}
- return statementList{
+ return s, statementList{
{&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
}.prepare(db)
}
-func (s *stateSnapshotStatements) insertState(
- ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
+func (s *stateSnapshotStatements) InsertState(
+ ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
- err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
+ err = txn.Stmt(s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
return
}
-func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
+func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs))
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index d44da858..03cfb7f0 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -40,10 +40,12 @@ type Database struct {
eventJSON tables.EventJSON
rooms tables.Rooms
transactions tables.Transactions
+ prevEvents tables.PreviousEvents
db *sql.DB
}
// Open a postgres database.
+// nolint: gocyclo
func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database, error) {
var d Database
var err error
@@ -77,6 +79,22 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
if err != nil {
return nil, err
}
+ stateBlock, err := NewPostgresStateBlockTable(d.db)
+ if err != nil {
+ return nil, err
+ }
+ stateSnapshot, err := NewPostgresStateSnapshotTable(d.db)
+ if err != nil {
+ return nil, err
+ }
+ roomAliases, err := NewPostgresRoomAliasesTable(d.db)
+ if err != nil {
+ return nil, err
+ }
+ d.prevEvents, err = NewPostgresPreviousEventsTable(d.db)
+ if err != nil {
+ return nil, err
+ }
d.Database = shared.Database{
DB: d.db,
EventTypesTable: d.eventTypes,
@@ -85,6 +103,10 @@ func Open(dataSourceName string, dbProperties internal.DbProperties) (*Database,
EventsTable: d.events,
RoomsTable: d.rooms,
TransactionsTable: d.transactions,
+ StateBlockTable: stateBlock,
+ StateSnapshotTable: stateSnapshot,
+ PrevEventsTable: d.prevEvents,
+ RoomAliasesTable: roomAliases,
}
return &d, nil
}
@@ -122,41 +144,6 @@ func (d *Database) assignStateKeyNID(
return eventStateKeyNID, err
}
-// AddState implements input.EventDatabase
-func (d *Database) AddState(
- ctx context.Context,
- roomNID types.RoomNID,
- stateBlockNIDs []types.StateBlockNID,
- state []types.StateEntry,
-) (types.StateSnapshotNID, error) {
- if len(state) > 0 {
- stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx)
- if err != nil {
- return 0, err
- }
- if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil {
- return 0, err
- }
- stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
- }
-
- return d.statements.insertState(ctx, roomNID, stateBlockNIDs)
-}
-
-// StateBlockNIDs implements state.RoomStateDatabase
-func (d *Database) StateBlockNIDs(
- ctx context.Context, stateNIDs []types.StateSnapshotNID,
-) ([]types.StateBlockNIDList, error) {
- return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs)
-}
-
-// StateEntries implements state.RoomStateDatabase
-func (d *Database) StateEntries(
- ctx context.Context, stateBlockNIDs []types.StateBlockNID,
-) ([]types.StateEntryList, error) {
- return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs)
-}
-
// GetLatestEventsForUpdate implements input.EventDatabase
func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomNID types.RoomNID,
@@ -222,7 +209,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN
// StorePreviousEvents implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences {
- if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
+ if err := u.d.prevEvents.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return err
}
}
@@ -231,7 +218,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
- err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
+ err := u.d.prevEvents.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
@@ -276,44 +263,6 @@ func (d *Database) GetInvitesForUser(
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
}
-// SetRoomAlias implements alias.RoomserverAliasAPIDB
-func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
- return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID)
-}
-
-// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
-func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
- return d.statements.selectRoomIDFromAlias(ctx, alias)
-}
-
-// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB
-func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
- return d.statements.selectAliasesFromRoomID(ctx, roomID)
-}
-
-// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB
-func (d *Database) GetCreatorIDForAlias(
- ctx context.Context, alias string,
-) (string, error) {
- return d.statements.selectCreatorIDFromAlias(ctx, alias)
-}
-
-// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
-func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
- return d.statements.deleteRoomAlias(ctx, alias)
-}
-
-// StateEntriesForTuples implements state.RoomStateDatabase
-func (d *Database) StateEntriesForTuples(
- ctx context.Context,
- stateBlockNIDs []types.StateBlockNID,
- stateKeyTuples []types.StateKeyTuple,
-) ([]types.StateEntryList, error) {
- return d.statements.bulkSelectFilteredStateBlockEntries(
- ctx, stateBlockNIDs, stateKeyTuples,
- )
-}
-
// MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,