diff options
Diffstat (limited to 'roomserver/storage/shared/storage.go')
-rw-r--r-- | roomserver/storage/shared/storage.go | 58 |
1 files changed, 55 insertions, 3 deletions
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 24b48772..096d5d7a 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -118,9 +118,24 @@ func (d *Database) StateEntriesForTuples( stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, + entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( + ctx, stateBlockNIDs, ) + if err != nil { + return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) + } + lists := []types.StateEntryList{} + for i, entry := range entries { + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) + } + lists = append(lists, types.StateEntryList{ + StateBlockNID: stateBlockNIDs[i], + StateEntries: entries, + }) + } + return lists, nil } func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { @@ -141,8 +156,28 @@ func (d *Database) AddState( stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { + if len(stateBlockNIDs) > 0 { + // Check to see if the event already appears in any of the existing state + // blocks. If it does then we should not add it again, as this will just + // result in excess state blocks and snapshots. + // TODO: Investigate why this is happening - probably input_events.go! + blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + if berr != nil { + return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr) + } + for i := len(state) - 1; i >= 0; i-- { + for _, events := range blocks { + for _, event := range events { + if state[i].EventNID == event { + state = append(state[:i], state[i+1:]...) + } + } + } + } + } err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if len(state) > 0 { + // If there's any state left to add then let's add new blocks. var stateBlockNID types.StateBlockNID stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) if err != nil { @@ -237,7 +272,24 @@ func (d *Database) StateBlockNIDs( func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( + ctx, stateBlockNIDs, + ) + if err != nil { + return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) + } + lists := make([]types.StateEntryList, 0, len(entries)) + for i, entry := range entries { + eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) + } + lists = append(lists, types.StateEntryList{ + StateBlockNID: stateBlockNIDs[i], + StateEntries: eventNIDs, + }) + } + return lists, nil } func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { |