aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/shared/storage.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/shared/storage.go')
-rw-r--r--roomserver/storage/shared/storage.go58
1 files changed, 55 insertions, 3 deletions
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 24b48772..096d5d7a 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -118,9 +118,24 @@ func (d *Database) StateEntriesForTuples(
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
- return d.StateBlockTable.BulkSelectFilteredStateBlockEntries(
- ctx, stateBlockNIDs, stateKeyTuples,
+ entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
+ ctx, stateBlockNIDs,
)
+ if err != nil {
+ return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
+ }
+ lists := []types.StateEntryList{}
+ for i, entry := range entries {
+ entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples)
+ if err != nil {
+ return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
+ }
+ lists = append(lists, types.StateEntryList{
+ StateBlockNID: stateBlockNIDs[i],
+ StateEntries: entries,
+ })
+ }
+ return lists, nil
}
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
@@ -141,8 +156,28 @@ func (d *Database) AddState(
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
+ if len(stateBlockNIDs) > 0 {
+ // Check to see if the event already appears in any of the existing state
+ // blocks. If it does then we should not add it again, as this will just
+ // result in excess state blocks and snapshots.
+ // TODO: Investigate why this is happening - probably input_events.go!
+ blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
+ if berr != nil {
+ return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
+ }
+ for i := len(state) - 1; i >= 0; i-- {
+ for _, events := range blocks {
+ for _, event := range events {
+ if state[i].EventNID == event {
+ state = append(state[:i], state[i+1:]...)
+ }
+ }
+ }
+ }
+ }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if len(state) > 0 {
+ // If there's any state left to add then let's add new blocks.
var stateBlockNID types.StateBlockNID
stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state)
if err != nil {
@@ -237,7 +272,24 @@ func (d *Database) StateBlockNIDs(
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
- return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
+ entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
+ ctx, stateBlockNIDs,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
+ }
+ lists := make([]types.StateEntryList, 0, len(entries))
+ for i, entry := range entries {
+ eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil)
+ if err != nil {
+ return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
+ }
+ lists = append(lists, types.StateEntryList{
+ StateBlockNID: stateBlockNIDs[i],
+ StateEntries: eventNIDs,
+ })
+ }
+ return lists, nil
}
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {