aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-02-04 10:39:34 +0000
committerGitHub <noreply@github.com>2022-02-04 10:39:34 +0000
commiteb352a5f6bdb48cb2d795e3fe2cd7d354580a761 (patch)
treedeefb3239e44be8938dcd784cc2094274e1d30ef
parent4d9f5b2e5787d23e1dbcebfda1c6d99d3498ec7e (diff)
Full roomserver input transactional isolation (#2141)
* Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input * Better transaction management * Tweak order * Handle cases where the room does not exist * Other fixes * More tweaks * Fill some gaps * Fill in the gaps * good lord it gets worse * Don't roll back transactions when events rejected * Pass through errors properly * Fix bugs * Fix incorrect error check * Don't panic on nil txns * Tweaks * Hopefully fix panics for good in SQLite this time * Fix rollback * Minor bug fixes with latest event updater * Some review comments * Revert "Some review comments" This reverts commit 0caf8cf53e62c33f7b83c52e9df1d963871f751e. * Fix a couple of bugs * Clearer commit and rollback results * Remove unnecessary prepares
-rw-r--r--roomserver/internal/helpers/auth.go13
-rw-r--r--roomserver/internal/input/input.go57
-rw-r--r--roomserver/internal/input/input_events.go84
-rw-r--r--roomserver/internal/input/input_latest_events.go18
-rw-r--r--roomserver/internal/input/input_membership.go4
-rw-r--r--roomserver/internal/input/input_missing.go12
-rw-r--r--roomserver/state/state.go17
-rw-r--r--roomserver/storage/interface.go7
-rw-r--r--roomserver/storage/postgres/event_json_table.go5
-rw-r--r--roomserver/storage/postgres/event_state_keys_table.go10
-rw-r--r--roomserver/storage/postgres/event_types_table.go5
-rw-r--r--roomserver/storage/postgres/events_table.go30
-rw-r--r--roomserver/storage/postgres/invite_table.go13
-rw-r--r--roomserver/storage/postgres/membership_table.go67
-rw-r--r--roomserver/storage/postgres/published_table.go10
-rw-r--r--roomserver/storage/postgres/room_aliases_table.go15
-rw-r--r--roomserver/storage/postgres/rooms_table.go27
-rw-r--r--roomserver/storage/postgres/state_block_table.go11
-rw-r--r--roomserver/storage/postgres/state_snapshot_table.go5
-rw-r--r--roomserver/storage/shared/latest_events_updater.go133
-rw-r--r--roomserver/storage/shared/room_updater.go262
-rw-r--r--roomserver/storage/shared/storage.go273
-rw-r--r--roomserver/storage/sqlite3/event_json_table.go11
-rw-r--r--roomserver/storage/sqlite3/event_state_keys_table.go21
-rw-r--r--roomserver/storage/sqlite3/event_types_table.go5
-rw-r--r--roomserver/storage/sqlite3/events_table.go19
-rw-r--r--roomserver/storage/sqlite3/invite_table.go13
-rw-r--r--roomserver/storage/sqlite3/membership_table.go45
-rw-r--r--roomserver/storage/sqlite3/published_table.go10
-rw-r--r--roomserver/storage/sqlite3/room_aliases_table.go15
-rw-r--r--roomserver/storage/sqlite3/rooms_table.go33
-rw-r--r--roomserver/storage/sqlite3/state_block_table.go9
-rw-r--r--roomserver/storage/sqlite3/state_snapshot_table.go3
-rw-r--r--roomserver/storage/sqlite3/storage.go42
-rw-r--r--roomserver/storage/tables/interface.go62
35 files changed, 867 insertions, 499 deletions
diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go
index ddda8081..9af0bf59 100644
--- a/roomserver/internal/helpers/auth.go
+++ b/roomserver/internal/helpers/auth.go
@@ -20,17 +20,22 @@ import (
"sort"
"github.com/matrix-org/dendrite/roomserver/state"
- "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
+type checkForAuthAndSoftFailStorage interface {
+ state.StateResolutionStorage
+ StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
+ RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
+}
+
// CheckForSoftFail returns true if the event should be soft-failed
// and false otherwise. The return error value should be checked before
// the soft-fail bool.
func CheckForSoftFail(
ctx context.Context,
- db storage.Database,
+ db checkForAuthAndSoftFailStorage,
event *gomatrixserverlib.HeaderedEvent,
stateEventIDs []string,
) (bool, error) {
@@ -92,7 +97,7 @@ func CheckForSoftFail(
// Returns the numeric IDs for the auth events.
func CheckAuthEvents(
ctx context.Context,
- db storage.Database,
+ db checkForAuthAndSoftFailStorage,
event *gomatrixserverlib.HeaderedEvent,
authEventIDs []string,
) ([]types.EventNID, error) {
@@ -193,7 +198,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
// loadAuthEvents loads the events needed for authentication from the supplied room state.
func loadAuthEvents(
ctx context.Context,
- db storage.Database,
+ db state.StateResolutionStorage,
needed gomatrixserverlib.StateNeeded,
state []types.StateEntry,
) (result authEvents, err error) {
diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go
index 7834e2ed..5bdec0a2 100644
--- a/roomserver/internal/input/input.go
+++ b/roomserver/internal/input/input.go
@@ -19,6 +19,7 @@ import (
"context"
"encoding/json"
"errors"
+ "fmt"
"sync"
"time"
@@ -38,6 +39,19 @@ import (
"github.com/tidwall/gjson"
)
+type retryAction int
+type commitAction int
+
+const (
+ doNotRetry retryAction = iota
+ retryLater
+)
+
+const (
+ commitTransaction commitAction = iota
+ rollbackTransaction
+)
+
var keyContentFields = map[string]string{
"m.room.join_rules": "join_rule",
"m.room.history_visibility": "history_visibility",
@@ -101,7 +115,8 @@ func (r *Inputer) Start() error {
_ = msg.InProgress() // resets the acknowledgement wait timer
defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
- if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil {
+ action, err := r.processRoomEventUsingUpdater(context.Background(), roomID, &inputRoomEvent)
+ if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err)
}
@@ -111,7 +126,12 @@ func (r *Inputer) Start() error {
"type": inputRoomEvent.Event.Type(),
}).Warn("Roomserver failed to process async event")
}
- _ = msg.Ack()
+ switch action {
+ case retryLater:
+ _ = msg.Nak()
+ case doNotRetry:
+ _ = msg.Ack()
+ }
})
},
// NATS wants to acknowledge automatically by default when the message is
@@ -131,6 +151,37 @@ func (r *Inputer) Start() error {
return err
}
+// processRoomEventUsingUpdater opens up a room updater and tries to
+// process the event. It returns whether or not we should positively
+// or negatively acknowledge the event (i.e. for NATS) and an error
+// if it occurred.
+func (r *Inputer) processRoomEventUsingUpdater(
+ ctx context.Context,
+ roomID string,
+ inputRoomEvent *api.InputRoomEvent,
+) (retryAction, error) {
+ roomInfo, err := r.DB.RoomInfo(ctx, roomID)
+ if err != nil {
+ return doNotRetry, fmt.Errorf("r.DB.RoomInfo: %w", err)
+ }
+ updater, err := r.DB.GetRoomUpdater(ctx, roomInfo)
+ if err != nil {
+ return retryLater, fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
+ }
+ action, err := r.processRoomEvent(ctx, updater, inputRoomEvent)
+ switch action {
+ case commitTransaction:
+ if cerr := updater.Commit(); cerr != nil {
+ return retryLater, fmt.Errorf("updater.Commit: %w", cerr)
+ }
+ case rollbackTransaction:
+ if rerr := updater.Rollback(); rerr != nil {
+ return retryLater, fmt.Errorf("updater.Rollback: %w", rerr)
+ }
+ }
+ return doNotRetry, err
+}
+
// InputRoomEvents implements api.RoomserverInternalAPI
func (r *Inputer) InputRoomEvents(
ctx context.Context,
@@ -177,7 +228,7 @@ func (r *Inputer) InputRoomEvents(
worker.Act(nil, func() {
defer eventsInProgress.Delete(index)
defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec()
- err := r.processRoomEvent(ctx, &inputRoomEvent)
+ _, err := r.processRoomEventUsingUpdater(ctx, roomID, &inputRoomEvent)
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
sentry.CaptureException(err)
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 16703616..f3fa83d8 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state"
+ "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -67,14 +68,15 @@ var processRoomEventDuration = prometheus.NewHistogramVec(
// nolint:gocyclo
func (r *Inputer) processRoomEvent(
ctx context.Context,
+ updater *shared.RoomUpdater,
input *api.InputRoomEvent,
-) (err error) {
+) (commitAction, error) {
select {
case <-ctx.Done():
// Before we do anything, make sure the context hasn't expired for this pending task.
// If it has then we'll give up straight away — it's probably a synchronous input
// request and the caller has already given up, but the inbox task was still queued.
- return context.DeadlineExceeded
+ return rollbackTransaction, context.DeadlineExceeded
default:
}
@@ -107,7 +109,7 @@ func (r *Inputer) processRoomEvent(
// if we have already got this event then do not process it again, if the input kind is an outlier.
// Outliers contain no extra information which may warrant a re-processing.
if input.Kind == api.KindOutlier {
- evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()})
+ evs, err2 := updater.EventsFromIDs(ctx, []string{event.EventID()})
if err2 == nil && len(evs) == 1 {
// check hash matches if we're on early room versions where the event ID was a random string
idFormat, err2 := headered.RoomVersion.EventIDFormat()
@@ -116,11 +118,11 @@ func (r *Inputer) processRoomEvent(
case gomatrixserverlib.EventIDFormatV1:
if bytes.Equal(event.EventReference().EventSHA256, evs[0].EventReference().EventSHA256) {
logger.Debugf("Already processed event; ignoring")
- return nil
+ return rollbackTransaction, nil
}
default:
logger.Debugf("Already processed event; ignoring")
- return nil
+ return rollbackTransaction, nil
}
}
}
@@ -134,8 +136,8 @@ func (r *Inputer) processRoomEvent(
AuthEventIDs: event.AuthEventIDs(),
PrevEventIDs: event.PrevEventIDs(),
}
- if err = r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil {
- return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err)
+ if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil {
+ return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err)
}
}
missingAuth := len(missingRes.MissingAuthEventIDs) > 0
@@ -146,8 +148,8 @@ func (r *Inputer) processRoomEvent(
RoomID: event.RoomID(),
ExcludeSelf: true,
}
- if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
- return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
+ if err := r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
+ return rollbackTransaction, fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
}
// Sort all of the servers into a map so that we can randomise
// their order. Then make sure that the input origin and the
@@ -176,8 +178,8 @@ func (r *Inputer) processRoomEvent(
isRejected := false
authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownEvents := map[string]*types.Event{}
- if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
- return fmt.Errorf("r.fetchAuthEvents: %w", err)
+ if err := r.fetchAuthEvents(ctx, updater, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil {
+ return rollbackTransaction, fmt.Errorf("r.fetchAuthEvents: %w", err)
}
// Check if the event is allowed by its auth events. If it isn't then
@@ -193,7 +195,7 @@ func (r *Inputer) processRoomEvent(
authEventNIDs := make([]types.EventNID, 0, len(authEventIDs))
for _, authEventID := range authEventIDs {
if _, ok := knownEvents[authEventID]; !ok {
- return fmt.Errorf("missing auth event %s", authEventID)
+ return rollbackTransaction, fmt.Errorf("missing auth event %s", authEventID)
}
authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID)
}
@@ -202,7 +204,8 @@ func (r *Inputer) processRoomEvent(
if input.Kind == api.KindNew {
// Check that the event passes authentication checks based on the
// current room state.
- softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
+ var err error
+ softfail, err = helpers.CheckForSoftFail(ctx, updater, headered, input.StateEventIDs)
if err != nil {
logger.WithError(err).Warn("Error authing soft-failed event")
}
@@ -227,7 +230,7 @@ func (r *Inputer) processRoomEvent(
origin: input.Origin,
inputer: r,
queryer: r.Queryer,
- db: r.DB,
+ db: updater,
federation: r.FSAPI,
keys: r.KeyRing,
roomsMu: internal.NewMutexByRoom(),
@@ -235,7 +238,7 @@ func (r *Inputer) processRoomEvent(
hadEvents: map[string]bool{},
haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{},
}
- if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
+ if err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil {
isRejected = true
rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err)
} else {
@@ -248,16 +251,16 @@ func (r *Inputer) processRoomEvent(
}
// Store the event.
- _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected)
+ _, _, stateAtEvent, redactionEvent, redactedEventID, err := updater.StoreEvent(ctx, event, authEventNIDs, isRejected)
if err != nil {
- return fmt.Errorf("r.DB.StoreEvent: %w", err)
+ return rollbackTransaction, fmt.Errorf("updater.StoreEvent: %w", err)
}
// if storing this event results in it being redacted then do so.
if !isRejected && redactedEventID == event.EventID() {
r, rerr := eventutil.RedactEvent(redactionEvent, event)
if rerr != nil {
- return fmt.Errorf("eventutil.RedactEvent: %w", rerr)
+ return rollbackTransaction, fmt.Errorf("eventutil.RedactEvent: %w", rerr)
}
event = r
}
@@ -268,23 +271,23 @@ func (r *Inputer) processRoomEvent(
if input.Kind == api.KindOutlier {
logger.Debug("Stored outlier")
hooks.Run(hooks.KindNewEventPersisted, headered)
- return nil
+ return commitTransaction, nil
}
- roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID())
+ roomInfo, err := updater.RoomInfo(ctx, event.RoomID())
if err != nil {
- return fmt.Errorf("r.DB.RoomInfo: %w", err)
+ return rollbackTransaction, fmt.Errorf("updater.RoomInfo: %w", err)
}
if roomInfo == nil {
- return fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID())
+ return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID())
}
if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Lets calculate one.
- err = r.calculateAndSetState(ctx, input, roomInfo, &stateAtEvent, event, isRejected)
+ err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected)
if err != nil {
- return fmt.Errorf("r.calculateAndSetState: %w", err)
+ return rollbackTransaction, fmt.Errorf("r.calculateAndSetState: %w", err)
}
}
@@ -294,13 +297,14 @@ func (r *Inputer) processRoomEvent(
"soft_fail": softfail,
"missing_prev": missingPrev,
}).Warn("Stored rejected event")
- return rejectionErr
+ return commitTransaction, rejectionErr
}
switch input.Kind {
case api.KindNew:
if err = r.updateLatestEvents(
ctx, // context
+ updater, // room updater
roomInfo, // room info for the room being updated
stateAtEvent, // state at event (below)
event, // event
@@ -308,7 +312,7 @@ func (r *Inputer) processRoomEvent(
input.TransactionID, // transaction ID
input.HasState, // rewrites state?
); err != nil {
- return fmt.Errorf("r.updateLatestEvents: %w", err)
+ return rollbackTransaction, fmt.Errorf("r.updateLatestEvents: %w", err)
}
case api.KindOld:
err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{
@@ -320,7 +324,7 @@ func (r *Inputer) processRoomEvent(
},
})
if err != nil {
- return fmt.Errorf("r.WriteOutputEvents (old): %w", err)
+ return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (old): %w", err)
}
}
@@ -339,14 +343,14 @@ func (r *Inputer) processRoomEvent(
},
})
if err != nil {
- return fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
+ return rollbackTransaction, fmt.Errorf("r.WriteOutputEvents (redactions): %w", err)
}
}
// Everything was OK — the latest events updater didn't error and
// we've sent output events. Finally, generate a hook call.
hooks.Run(hooks.KindNewEventPersisted, headered)
- return nil
+ return commitTransaction, nil
}
// fetchAuthEvents will check to see if any of the
@@ -358,6 +362,7 @@ func (r *Inputer) processRoomEvent(
// they are now in the database.
func (r *Inputer) fetchAuthEvents(
ctx context.Context,
+ updater *shared.RoomUpdater,
logger *logrus.Entry,
event *gomatrixserverlib.HeaderedEvent,
auth *gomatrixserverlib.AuthEvents,
@@ -375,7 +380,7 @@ func (r *Inputer) fetchAuthEvents(
}
for _, authEventID := range authEventIDs {
- authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID})
+ authEvents, err := updater.EventsFromIDs(ctx, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{}
continue
@@ -454,9 +459,9 @@ func (r *Inputer) fetchAuthEvents(
}
// Finally, store the event in the database.
- eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
+ eventNID, _, _, _, _, err := updater.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
if err != nil {
- return fmt.Errorf("r.DB.StoreEvent: %w", err)
+ return fmt.Errorf("updater.StoreEvent: %w", err)
}
// Now we know about this event, it was stored and the signatures were OK.
@@ -471,6 +476,7 @@ func (r *Inputer) fetchAuthEvents(
func (r *Inputer) calculateAndSetState(
ctx context.Context,
+ updater *shared.RoomUpdater,
input *api.InputRoomEvent,
roomInfo *types.RoomInfo,
stateAtEvent *types.StateAtEvent,
@@ -478,14 +484,14 @@ func (r *Inputer) calculateAndSetState(
isRejected bool,
) error {
var err error
- roomState := state.NewStateResolution(r.DB, roomInfo)
+ roomState := state.NewStateResolution(updater, roomInfo)
if input.HasState {
// Check here if we think we're in the room already.
stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID
// Request join memberships only for local users only.
- if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
+ if joinEventNIDs, err = updater.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
// If we have no local users that are joined to the room then any state about
// the room that we have is quite possibly out of date. Therefore in that case
// we should overwrite it rather than merge it.
@@ -495,13 +501,13 @@ func (r *Inputer) calculateAndSetState(
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
var entries []types.StateEntry
- if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
- return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err)
+ if entries, err = updater.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
+ return fmt.Errorf("updater.StateEntriesForEventIDs: %w", err)
}
entries = types.DeduplicateStateEntries(entries)
- if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
- return fmt.Errorf("r.DB.AddState: %w", err)
+ if stateAtEvent.BeforeStateSnapshotNID, err = updater.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
+ return fmt.Errorf("updater.AddState: %w", err)
}
} else {
stateAtEvent.Overwrite = false
@@ -512,7 +518,7 @@ func (r *Inputer) calculateAndSetState(
}
}
- err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
+ err = updater.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
if err != nil {
return fmt.Errorf("r.DB.SetState: %w", err)
}
diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go
index 6137941e..5173d3ab 100644
--- a/roomserver/internal/input/input_latest_events.go
+++ b/roomserver/internal/input/input_latest_events.go
@@ -20,7 +20,6 @@ import (
"context"
"fmt"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
@@ -48,6 +47,7 @@ import (
// Can only be called once at a time
func (r *Inputer) updateLatestEvents(
ctx context.Context,
+ updater *shared.RoomUpdater,
roomInfo *types.RoomInfo,
stateAtEvent types.StateAtEvent,
event *gomatrixserverlib.Event,
@@ -55,13 +55,6 @@ func (r *Inputer) updateLatestEvents(
transactionID *api.TransactionID,
rewritesState bool,
) (err error) {
- updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
- if err != nil {
- return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
- }
- succeeded := false
- defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
-
u := latestEventsUpdater{
ctx: ctx,
api: r,
@@ -78,7 +71,6 @@ func (r *Inputer) updateLatestEvents(
return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
}
- succeeded = true
return
}
@@ -89,7 +81,7 @@ func (r *Inputer) updateLatestEvents(
type latestEventsUpdater struct {
ctx context.Context
api *Inputer
- updater *shared.LatestEventsUpdater
+ updater *shared.RoomUpdater
roomInfo *types.RoomInfo
stateAtEvent types.StateAtEvent
event *gomatrixserverlib.Event
@@ -199,7 +191,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
func (u *latestEventsUpdater) latestState() error {
var err error
- roomState := state.NewStateResolution(u.api.DB, u.roomInfo)
+ roomState := state.NewStateResolution(u.updater, u.roomInfo)
// Work out if the state at the extremities has actually changed
// or not. If they haven't then we won't bother doing all of the
@@ -413,7 +405,7 @@ func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.Ro
if len(extraEventIDs) == 0 {
return nil, nil
}
- extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs)
+ extraEvents, err := u.updater.EventsFromIDs(u.ctx, extraEventIDs)
if err != nil {
return nil, err
}
@@ -436,7 +428,7 @@ func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error)
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
- return u.api.DB.EventIDs(u.ctx, stateEventNIDs)
+ return u.updater.EventIDs(u.ctx, stateEventNIDs)
}
type eventNIDSorter []types.EventNID
diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go
index 2511097d..ff3ed7e5 100644
--- a/roomserver/internal/input/input_membership.go
+++ b/roomserver/internal/input/input_membership.go
@@ -31,7 +31,7 @@ import (
// consumers about the invites added or retired by the change in current state.
func (r *Inputer) updateMemberships(
ctx context.Context,
- updater *shared.LatestEventsUpdater,
+ updater *shared.RoomUpdater,
removed, added []types.StateEntry,
) ([]api.OutputEvent, error) {
changes := membershipChanges(removed, added)
@@ -79,7 +79,7 @@ func (r *Inputer) updateMemberships(
}
func (r *Inputer) updateMembership(
- updater *shared.LatestEventsUpdater,
+ updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent,
diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go
index d401fa0e..4cd2b3de 100644
--- a/roomserver/internal/input/input_missing.go
+++ b/roomserver/internal/input/input_missing.go
@@ -11,7 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/query"
- "github.com/matrix-org/dendrite/roomserver/storage"
+ "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@@ -19,7 +19,7 @@ import (
type missingStateReq struct {
origin gomatrixserverlib.ServerName
- db storage.Database
+ db *shared.RoomUpdater
inputer *Inputer
queryer *query.Queryer
keys gomatrixserverlib.JSONVerifier
@@ -78,7 +78,7 @@ func (t *missingStateReq) processEventWithMissingState(
// we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled
// in the gap in the DAG
for _, newEvent := range newEvents {
- err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
+ _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindNew,
Event: newEvent.Headered(roomVersion),
Origin: t.origin,
@@ -187,7 +187,7 @@ func (t *missingStateReq) processEventWithMissingState(
}
// TODO: we could do this concurrently?
for _, ire := range outlierRoomEvents {
- if err = t.inputer.processRoomEvent(ctx, &ire); err != nil {
+ if _, err = t.inputer.processRoomEvent(ctx, t.db, &ire); err != nil {
return fmt.Errorf("t.inputer.processRoomEvent[outlier]: %w", err)
}
}
@@ -200,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState(
stateIDs = append(stateIDs, event.EventID())
}
- err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
+ _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld,
Event: backwardsExtremity.Headered(roomVersion),
Origin: t.origin,
@@ -217,7 +217,7 @@ func (t *missingStateReq) processEventWithMissingState(
// they will automatically fast-forward based on the room state at the
// extremity in the last step.
for _, newEvent := range newEvents {
- err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{
+ _, err = t.inputer.processRoomEvent(ctx, t.db, &api.InputRoomEvent{
Kind: api.KindOld,
Event: newEvent.Headered(roomVersion),
Origin: t.origin,
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index 15d592b4..e5f69521 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -22,7 +22,6 @@ import (
"sort"
"time"
- "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
@@ -30,13 +29,25 @@ import (
"github.com/matrix-org/gomatrixserverlib"
)
+type StateResolutionStorage interface {
+ EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
+ EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
+ StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
+ StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
+ SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
+ StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
+ StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
+ AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
+ Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
+}
+
type StateResolution struct {
- db storage.Database
+ db StateResolutionStorage
roomInfo *types.RoomInfo
events map[types.EventNID]*gomatrixserverlib.Event
}
-func NewStateResolution(db storage.Database, roomInfo *types.RoomInfo) StateResolution {
+func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
return StateResolution{
db: db,
roomInfo: roomInfo,
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index 15764366..a9851e05 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -86,11 +86,10 @@ type Database interface {
// Lookup the event IDs for a batch of event numeric IDs.
// Returns an error if the retrieval went wrong.
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
- // Look up the latest events in a room in preparation for an update.
- // The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
- // Returns the latest events in the room and the last eventID sent to the log along with an updater.
+ // Opens and returns a room updater, which locks the room and opens a transaction.
+ // The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error.
// If this returns an error then no further action is required.
- GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error)
+ GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
// Look up event references for the latest events in the room and the current state snapshot.
// Returns the latest events, the current state and the maximum depth of the latest events plus 1.
// Returns an error if there was a problem talking to the database.
diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go
index 32e45782..433e445d 100644
--- a/roomserver/storage/postgres/event_json_table.go
+++ b/roomserver/storage/postgres/event_json_table.go
@@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON(
}
func (s *eventJSONStatements) BulkSelectEventJSON(
- ctx context.Context, eventNIDs []types.EventNID,
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) {
- rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt)
+ rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go
index 3a7cf03e..762b3a1f 100644
--- a/roomserver/storage/postgres/event_state_keys_table.go
+++ b/roomserver/storage/postgres/event_state_keys_table.go
@@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
- ctx context.Context, eventStateKeys []string,
+ ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
- rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt)
+ rows, err := stmt.QueryContext(
ctx, pq.StringArray(eventStateKeys),
)
if err != nil {
@@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
- ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
+ ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i])
}
- rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt)
+ rows, err := stmt.QueryContext(ctx, nIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/event_types_table.go b/roomserver/storage/postgres/event_types_table.go
index e558072a..1d5de582 100644
--- a/roomserver/storage/postgres/event_types_table.go
+++ b/roomserver/storage/postgres/event_types_table.go
@@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID(
}
func (s *eventTypeStatements) BulkSelectEventTypeNID(
- ctx context.Context, eventTypes []string,
+ ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
- rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt)
+ rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes))
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go
index 778cd8d7..6c384775 100644
--- a/roomserver/storage/postgres/events_table.go
+++ b/roomserver/storage/postgres/events_table.go
@@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID(
- ctx context.Context, eventIDs []string,
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) {
- rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
+ rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
@@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByNID lookups a list of state events by event NID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID(
- ctx context.Context, eventNIDs []types.EventNID,
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
- rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
+ rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
if err != nil {
return nil, err
}
@@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID(
- ctx context.Context, eventIDs []string,
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) {
- rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
+ rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
@@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference(
}
// bulkSelectEventID returns a map from numeric event ID to string event ID.
-func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
- rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
+func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt)
+ rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
@@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
-func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
- rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
+func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
+ rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
@@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
}
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
- ctx context.Context, eventNIDs []types.EventNID,
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) {
- rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
+ stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt)
+ rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/invite_table.go b/roomserver/storage/postgres/invite_table.go
index 344302c8..176c16e4 100644
--- a/roomserver/storage/postgres/invite_table.go
+++ b/roomserver/storage/postgres/invite_table.go
@@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
}
func (s *inviteStatements) InsertInviteEvent(
- ctx context.Context,
- txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
+ ctx context.Context, txn *sql.Tx,
+ inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
@@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent(
}
func (s *inviteStatements) UpdateInviteRetired(
- ctx context.Context,
- txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
@@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) {
- rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
+ stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
+ rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID,
)
if err != nil {
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index b0d906c8..48c2c35c 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
}
func (s *membershipStatements) InsertMembership(
- ctx context.Context,
- txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
@@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership(
}
func (s *membershipStatements) SelectMembershipForUpdate(
- ctx context.Context,
- txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership tables.MembershipState, err error) {
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
ctx, roomNID, targetUserNID,
@@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate(
}
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
- err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
+ stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
+ err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten)
return
}
func (s *membershipStatements) SelectMembershipsFromRoom(
- ctx context.Context, roomNID types.RoomNID, localOnly bool,
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt
if localOnly {
@@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else {
stmt = s.selectMembershipsFromRoomStmt
}
+ stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil {
return
@@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
}
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows
@@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt
}
+ stmt = sqlutil.TxStmt(txn, stmt)
rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
@@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
}
func (s *membershipStatements) UpdateMembership(
- ctx context.Context,
- txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
@@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership(
}
func (s *membershipStatements) SelectRoomsWithMembership(
- ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
+ ctx context.Context, txn *sql.Tx,
+ userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
- rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
+ stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
+ rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil {
return nil, err
}
@@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
-func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
+func (s *membershipStatements) SelectJoinedUsersSetForRooms(
+ ctx context.Context, txn *sql.Tx,
+ roomNIDs []types.RoomNID,
+) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
- rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
+ stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
+ rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
if err != nil {
return nil, err
}
@@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err()
}
-func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
- rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
+func (s *membershipStatements) SelectKnownUsers(
+ ctx context.Context, txn *sql.Tx,
+ userID types.EventStateKeyNID, searchString string, limit int,
+) ([]string, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
+ rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
@@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
}
func (s *membershipStatements) UpdateForgetMembership(
- ctx context.Context,
- txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
- forget bool,
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, roomNID, targetUserNID, forget,
@@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership(
return err
}
-func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
+func (s *membershipStatements) SelectLocalServerInRoom(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID,
+) (bool, error) {
var nid types.RoomNID
- err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
+ stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
+ err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
@@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil
}
-func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
+func (s *membershipStatements) SelectServerInRoom(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, serverName gomatrixserverlib.ServerName,
+) (bool, error) {
var nid types.RoomNID
- err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
+ stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
+ err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go
index 8deb6844..15985fcd 100644
--- a/roomserver/storage/postgres/published_table.go
+++ b/roomserver/storage/postgres/published_table.go
@@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished(
}
func (s *publishedStatements) SelectPublishedFromRoomID(
- ctx context.Context, roomID string,
+ ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) {
- err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
+ stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
+ err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows {
return false, nil
}
@@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
}
func (s *publishedStatements) SelectAllPublishedRooms(
- ctx context.Context, published bool,
+ ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) {
- rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
+ stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
+ rows, err := stmt.QueryContext(ctx, published)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go
index 031825fe..d13df8e7 100644
--- a/roomserver/storage/postgres/room_aliases_table.go
+++ b/roomserver/storage/postgres/room_aliases_table.go
@@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
}
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
- ctx context.Context, alias string,
+ ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) {
- err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
+ stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
+ err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows {
return "", nil
}
@@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
}
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
- ctx context.Context, roomID string,
+ ctx context.Context, txn *sql.Tx, roomID string,
) ([]string, error) {
- rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
+ stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
+ rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}
@@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
}
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
- ctx context.Context, alias string,
+ ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) {
- err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
+ stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
+ err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows {
return "", nil
}
diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go
index f51eba4d..b2685084 100644
--- a/roomserver/storage/postgres/rooms_table.go
+++ b/roomserver/storage/postgres/rooms_table.go
@@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}
-func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
- rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
+func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
+ rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
@@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID(
return types.RoomNID(roomNID), err
}
-func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
+func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDs pq.Int64Array
- err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
+ stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
+ err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
)
if err == sql.ErrNoRows {
@@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs(
) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array
var stateSnapshotNID int64
- stmt := s.selectLatestEventNIDsStmt
+ stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
if err != nil {
return nil, 0, err
@@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs(
}
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
- ctx context.Context, roomNIDs []types.RoomNID,
+ ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
- rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
+ stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt)
+ rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
if err != nil {
return nil, err
}
@@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil
}
-func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
+func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
var array pq.Int64Array
for _, nid := range roomNIDs {
array = append(array, int64(nid))
}
- rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array)
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt)
+ rows, err := stmt.QueryContext(ctx, array)
if err != nil {
return nil, err
}
@@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil
}
-func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
+func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
var array pq.StringArray
for _, roomID := range roomIDs {
array = append(array, roomID)
}
- rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array)
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt)
+ rows, err := stmt.QueryContext(ctx, array)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go
index 27d85e83..6f8f9e1b 100644
--- a/roomserver/storage/postgres/state_block_table.go
+++ b/roomserver/storage/postgres/state_block_table.go
@@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
}
func (s *stateBlockStatements) BulkInsertStateData(
- ctx context.Context,
- txn *sql.Tx,
+ ctx context.Context, txn *sql.Tx,
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
@@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData(
for _, e := range entries {
nids = append(nids, e.EventNID)
}
- err = s.insertStateDataStmt.QueryRowContext(
+ stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
+ err = stmt.QueryRowContext(
ctx, nids.Hash(), eventNIDsAsArray(nids),
).Scan(&id)
return
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
- ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
+ ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) {
- rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
+ rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go
index 4fc0fa48..ce9f2463 100644
--- a/roomserver/storage/postgres/state_snapshot_table.go
+++ b/roomserver/storage/postgres/state_snapshot_table.go
@@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState(
}
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
- ctx context.Context, stateNIDs []types.StateSnapshotNID,
+ ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
}
- rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
+ stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
+ rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go
deleted file mode 100644
index 36865081..00000000
--- a/roomserver/storage/shared/latest_events_updater.go
+++ /dev/null
@@ -1,133 +0,0 @@
-package shared
-
-import (
- "context"
- "database/sql"
- "fmt"
-
- "github.com/matrix-org/dendrite/roomserver/types"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-type LatestEventsUpdater struct {
- transaction
- d *Database
- roomInfo types.RoomInfo
- latestEvents []types.StateAtEventAndReference
- lastEventIDSent string
- currentStateSnapshotNID types.StateSnapshotNID
-}
-
-func rollback(txn *sql.Tx) {
- if txn == nil {
- return
- }
- txn.Rollback() // nolint: errcheck
-}
-
-func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) {
- eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
- d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
- if err != nil {
- rollback(txn)
- return nil, err
- }
- stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
- if err != nil {
- rollback(txn)
- return nil, err
- }
- var lastEventIDSent string
- if lastEventNIDSent != 0 {
- lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
- if err != nil {
- rollback(txn)
- return nil, err
- }
- }
- return &LatestEventsUpdater{
- transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
- }, nil
-}
-
-// RoomVersion implements types.RoomRecentEventsUpdater
-func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
- return u.roomInfo.RoomVersion
-}
-
-// LatestEvents implements types.RoomRecentEventsUpdater
-func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
- return u.latestEvents
-}
-
-// LastEventIDSent implements types.RoomRecentEventsUpdater
-func (u *LatestEventsUpdater) LastEventIDSent() string {
- return u.lastEventIDSent
-}
-
-// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
-func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
- return u.currentStateSnapshotNID
-}
-
-// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
-func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
- for _, ref := range previousEventReferences {
- if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
- return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
- }
- }
- return nil
-}
-
-// IsReferenced implements types.RoomRecentEventsUpdater
-func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
- err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
- if err == nil {
- return true, nil
- }
- if err == sql.ErrNoRows {
- return false, nil
- }
- return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
-}
-
-// SetLatestEvents implements types.RoomRecentEventsUpdater
-func (u *LatestEventsUpdater) SetLatestEvents(
- roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
- currentStateSnapshotNID types.StateSnapshotNID,
-) error {
- eventNIDs := make([]types.EventNID, len(latest))
- for i := range latest {
- eventNIDs[i] = latest[i].EventNID
- }
- return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
- if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
- return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
- }
- if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
- if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
- roomInfo.StateSnapshotNID = currentStateSnapshotNID
- roomInfo.IsStub = false
- u.d.Cache.StoreRoomInfo(roomID, roomInfo)
- }
- }
- return nil
- })
-}
-
-// HasEventBeenSent implements types.RoomRecentEventsUpdater
-func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
- return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
-}
-
-// MarkEventAsSent implements types.RoomRecentEventsUpdater
-func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
- return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
- return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
- })
-}
-
-func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
- return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
-}
diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go
new file mode 100644
index 00000000..bb9f5dc6
--- /dev/null
+++ b/roomserver/storage/shared/room_updater.go
@@ -0,0 +1,262 @@
+package shared
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/roomserver/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+type RoomUpdater struct {
+ transaction
+ d *Database
+ roomInfo *types.RoomInfo
+ latestEvents []types.StateAtEventAndReference
+ lastEventIDSent string
+ currentStateSnapshotNID types.StateSnapshotNID
+}
+
+func rollback(txn *sql.Tx) {
+ if txn == nil {
+ return
+ }
+ txn.Rollback() // nolint: errcheck
+}
+
+func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) {
+ // If the roomInfo is nil then that means that the room doesn't exist
+ // yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that
+ // would involve locking a row on the table that doesn't exist. Instead
+ // we will just run with a normal database transaction. It'll either
+ // succeed, processing a create event which creates the room, or it won't.
+ if roomInfo == nil {
+ return &RoomUpdater{
+ transaction{ctx, txn}, d, nil, nil, "", 0,
+ }, nil
+ }
+
+ eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
+ d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
+ if err != nil {
+ rollback(txn)
+ return nil, err
+ }
+ stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
+ if err != nil {
+ rollback(txn)
+ return nil, err
+ }
+ var lastEventIDSent string
+ if lastEventNIDSent != 0 {
+ lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
+ if err != nil {
+ rollback(txn)
+ return nil, err
+ }
+ }
+ return &RoomUpdater{
+ transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
+ }, nil
+}
+
+// Implements sqlutil.Transaction
+func (u *RoomUpdater) Commit() error {
+ if u.txn == nil { // SQLite mode probably
+ return nil
+ }
+ return u.txn.Commit()
+}
+
+// Implements sqlutil.Transaction
+func (u *RoomUpdater) Rollback() error {
+ if u.txn == nil { // SQLite mode probably
+ return nil
+ }
+ return u.txn.Rollback()
+}
+
+// RoomVersion implements types.RoomRecentEventsUpdater
+func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
+ return u.roomInfo.RoomVersion
+}
+
+// LatestEvents implements types.RoomRecentEventsUpdater
+func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference {
+ return u.latestEvents
+}
+
+// LastEventIDSent implements types.RoomRecentEventsUpdater
+func (u *RoomUpdater) LastEventIDSent() string {
+ return u.lastEventIDSent
+}
+
+// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
+func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
+ return u.currentStateSnapshotNID
+}
+
+// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
+func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
+ return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ for _, ref := range previousEventReferences {
+ if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
+ return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
+ }
+ }
+ return nil
+ })
+}
+
+func (u *RoomUpdater) Events(
+ ctx context.Context, eventNIDs []types.EventNID,
+) ([]types.Event, error) {
+ return u.d.events(ctx, u.txn, eventNIDs)
+}
+
+func (u *RoomUpdater) SnapshotNIDFromEventID(
+ ctx context.Context, eventID string,
+) (types.StateSnapshotNID, error) {
+ return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID)
+}
+
+func (u *RoomUpdater) StoreEvent(
+ ctx context.Context, event *gomatrixserverlib.Event,
+ authEventNIDs []types.EventNID, isRejected bool,
+) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
+ return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected)
+}
+
+func (u *RoomUpdater) StateBlockNIDs(
+ ctx context.Context, stateNIDs []types.StateSnapshotNID,
+) ([]types.StateBlockNIDList, error) {
+ return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs)
+}
+
+func (u *RoomUpdater) StateEntries(
+ ctx context.Context, stateBlockNIDs []types.StateBlockNID,
+) ([]types.StateEntryList, error) {
+ return u.d.stateEntries(ctx, u.txn, stateBlockNIDs)
+}
+
+func (u *RoomUpdater) StateEntriesForTuples(
+ ctx context.Context,
+ stateBlockNIDs []types.StateBlockNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntryList, error) {
+ return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples)
+}
+
+func (u *RoomUpdater) AddState(
+ ctx context.Context,
+ roomNID types.RoomNID,
+ stateBlockNIDs []types.StateBlockNID,
+ state []types.StateEntry,
+) (stateNID types.StateSnapshotNID, err error) {
+ return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state)
+}
+
+func (u *RoomUpdater) SetState(
+ ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
+) error {
+ return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID)
+ })
+}
+
+func (u *RoomUpdater) EventTypeNIDs(
+ ctx context.Context, eventTypes []string,
+) (map[string]types.EventTypeNID, error) {
+ return u.d.eventTypeNIDs(ctx, u.txn, eventTypes)
+}
+
+func (u *RoomUpdater) EventStateKeyNIDs(
+ ctx context.Context, eventStateKeys []string,
+) (map[string]types.EventStateKeyNID, error) {
+ return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys)
+}
+
+func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
+ return u.d.roomInfo(ctx, u.txn, roomID)
+}
+
+func (u *RoomUpdater) EventIDs(
+ ctx context.Context, eventNIDs []types.EventNID,
+) (map[types.EventNID]string, error) {
+ return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
+}
+
+func (u *RoomUpdater) StateAtEventIDs(
+ ctx context.Context, eventIDs []string,
+) ([]types.StateAtEvent, error) {
+ return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
+}
+
+func (u *RoomUpdater) StateEntriesForEventIDs(
+ ctx context.Context, eventIDs []string,
+) ([]types.StateEntry, error) {
+ return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
+}
+
+func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
+ return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
+}
+
+func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
+ ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
+) ([]types.EventNID, error) {
+ return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
+}
+
+// IsReferenced implements types.RoomRecentEventsUpdater
+func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
+ err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
+ if err == nil {
+ return true, nil
+ }
+ if err == sql.ErrNoRows {
+ return false, nil
+ }
+ return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
+}
+
+// SetLatestEvents implements types.RoomRecentEventsUpdater
+func (u *RoomUpdater) SetLatestEvents(
+ roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
+ currentStateSnapshotNID types.StateSnapshotNID,
+) error {
+ eventNIDs := make([]types.EventNID, len(latest))
+ for i := range latest {
+ eventNIDs[i] = latest[i].EventNID
+ }
+ return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
+ return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
+ }
+ if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
+ if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
+ roomInfo.StateSnapshotNID = currentStateSnapshotNID
+ roomInfo.IsStub = false
+ u.d.Cache.StoreRoomInfo(roomID, roomInfo)
+ }
+ }
+ return nil
+ })
+}
+
+// HasEventBeenSent implements types.RoomRecentEventsUpdater
+func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
+ return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
+}
+
+// MarkEventAsSent implements types.RoomRecentEventsUpdater
+func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
+ return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
+ })
+}
+
+func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
+ return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
+}
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index d4c5ebb5..127cd1f5 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -26,23 +26,23 @@ import (
const redactionsArePermanent = true
type Database struct {
- DB *sql.DB
- Cache caching.RoomServerCaches
- Writer sqlutil.Writer
- EventsTable tables.Events
- EventJSONTable tables.EventJSON
- EventTypesTable tables.EventTypes
- EventStateKeysTable tables.EventStateKeys
- RoomsTable tables.Rooms
- StateSnapshotTable tables.StateSnapshot
- StateBlockTable tables.StateBlock
- RoomAliasesTable tables.RoomAliases
- PrevEventsTable tables.PreviousEvents
- InvitesTable tables.Invites
- MembershipTable tables.Membership
- PublishedTable tables.Published
- RedactionsTable tables.Redactions
- GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error)
+ DB *sql.DB
+ Cache caching.RoomServerCaches
+ Writer sqlutil.Writer
+ EventsTable tables.Events
+ EventJSONTable tables.EventJSON
+ EventTypesTable tables.EventTypes
+ EventStateKeysTable tables.EventStateKeys
+ RoomsTable tables.Rooms
+ StateSnapshotTable tables.StateSnapshot
+ StateBlockTable tables.StateBlock
+ RoomAliasesTable tables.RoomAliases
+ PrevEventsTable tables.PreviousEvents
+ InvitesTable tables.Invites
+ MembershipTable tables.Membership
+ PublishedTable tables.Published
+ RedactionsTable tables.Redactions
+ GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
}
func (d *Database) SupportsConcurrentRoomInputs() bool {
@@ -52,6 +52,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
+ return d.eventTypeNIDs(ctx, nil, eventTypes)
+}
+
+func (d *Database) eventTypeNIDs(
+ ctx context.Context, txn *sql.Tx, eventTypes []string,
+) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID)
remaining := []string{}
for _, eventType := range eventTypes {
@@ -62,7 +68,7 @@ func (d *Database) EventTypeNIDs(
}
}
if len(remaining) > 0 {
- nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining)
+ nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
if err != nil {
return nil, err
}
@@ -77,12 +83,18 @@ func (d *Database) EventTypeNIDs(
func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
- return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs)
+ return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
}
func (d *Database) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
+ return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
+}
+
+func (d *Database) eventStateKeyNIDs(
+ ctx context.Context, txn *sql.Tx, eventStateKeys []string,
+) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID)
remaining := []string{}
for _, eventStateKey := range eventStateKeys {
@@ -93,7 +105,7 @@ func (d *Database) EventStateKeyNIDs(
}
}
if len(remaining) > 0 {
- nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining)
+ nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
if err != nil {
return nil, err
}
@@ -108,7 +120,7 @@ func (d *Database) EventStateKeyNIDs(
func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
- return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
+ return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs)
}
func (d *Database) StateEntriesForTuples(
@@ -116,15 +128,23 @@ func (d *Database) StateEntriesForTuples(
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
+ return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples)
+}
+
+func (d *Database) stateEntriesForTuples(
+ ctx context.Context, txn *sql.Tx,
+ stateBlockNIDs []types.StateBlockNID,
+ stateKeyTuples []types.StateKeyTuple,
+) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
- ctx, stateBlockNIDs,
+ ctx, txn, stateBlockNIDs,
)
if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
}
lists := []types.StateEntryList{}
for i, entry := range entries {
- entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples)
+ entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
}
@@ -137,10 +157,14 @@ func (d *Database) StateEntriesForTuples(
}
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
+ return d.roomInfo(ctx, nil, roomID)
+}
+
+func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil
}
- roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID)
+ roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo)
@@ -154,12 +178,21 @@ func (d *Database) AddState(
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
+ return d.addState(ctx, nil, roomNID, stateBlockNIDs, state)
+}
+
+func (d *Database) addState(
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID,
+ stateBlockNIDs []types.StateBlockNID,
+ state []types.StateEntry,
+) (stateNID types.StateSnapshotNID, err error) {
if len(stateBlockNIDs) > 0 && len(state) > 0 {
// Check to see if the event already appears in any of the existing state
// blocks. If it does then we should not add it again, as this will just
// result in excess state blocks and snapshots.
// TODO: Investigate why this is happening - probably input_events.go!
- blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
+ blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
if berr != nil {
return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
}
@@ -180,7 +213,7 @@ func (d *Database) AddState(
}
}
}
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
if len(state) > 0 {
// If there's any state left to add then let's add new blocks.
var stateBlockNID types.StateBlockNID
@@ -205,7 +238,13 @@ func (d *Database) AddState(
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
- return d.EventsTable.BulkSelectEventNID(ctx, eventIDs)
+ return d.eventNIDs(ctx, nil, eventIDs)
+}
+
+func (d *Database) eventNIDs(
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
+) (map[string]types.EventNID, error) {
+ return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
}
func (d *Database) SetState(
@@ -219,24 +258,34 @@ func (d *Database) SetState(
func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
- return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs)
+ return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
}
func (d *Database) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
- _, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID)
+ return d.snapshotNIDFromEventID(ctx, nil, eventID)
+}
+
+func (d *Database) snapshotNIDFromEventID(
+ ctx context.Context, txn *sql.Tx, eventID string,
+) (types.StateSnapshotNID, error) {
+ _, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
return stateNID, err
}
func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
- return d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
+ return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
}
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
- nidMap, err := d.EventNIDs(ctx, eventIDs)
+ return d.eventsFromIDs(ctx, nil, eventIDs)
+}
+
+func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
+ nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
if err != nil {
return nil, err
}
@@ -246,7 +295,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type
nids = append(nids, nid)
}
- return d.Events(ctx, nids)
+ return d.events(ctx, txn, nids)
}
func (d *Database) LatestEventIDs(
@@ -271,21 +320,33 @@ func (d *Database) LatestEventIDs(
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
- return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs)
+ return d.stateBlockNIDs(ctx, nil, stateNIDs)
+}
+
+func (d *Database) stateBlockNIDs(
+ ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
+) ([]types.StateBlockNIDList, error) {
+ return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
}
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
+ return d.stateEntries(ctx, nil, stateBlockNIDs)
+}
+
+func (d *Database) stateEntries(
+ ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
+) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
- ctx, stateBlockNIDs,
+ ctx, txn, stateBlockNIDs,
)
if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
}
lists := make([]types.StateEntryList, 0, len(entries))
for i, entry := range entries {
- eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil)
+ eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
}
@@ -304,17 +365,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string
}
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
- return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias)
+ return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias)
}
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
- return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID)
+ return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID)
}
func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
- return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias)
+ return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias)
}
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
@@ -335,7 +396,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
d.MembershipTable.SelectMembershipFromRoomAndTarget(
- ctx, roomNID, requestSenderUserNID,
+ ctx, nil, roomNID, requestSenderUserNID,
)
if err == sql.ErrNoRows {
// The user has never been a member of that room
@@ -350,13 +411,19 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
+ return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly)
+}
+
+func (d *Database) getMembershipEventNIDsForRoom(
+ ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool,
+) ([]types.EventNID, error) {
if joinOnly {
return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
- ctx, roomNID, tables.MembershipStateJoin, localOnly,
+ ctx, txn, roomNID, tables.MembershipStateJoin, localOnly,
)
}
- return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly)
+ return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
}
func (d *Database) GetInvitesForUser(
@@ -364,22 +431,28 @@ func (d *Database) GetInvitesForUser(
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
- return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
+ return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
}
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
- eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
+ return d.events(ctx, nil, eventNIDs)
+}
+
+func (d *Database) events(
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
+) ([]types.Event, error) {
+ eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil {
return nil, err
}
- eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
+ eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
var roomNIDs map[types.EventNID]types.RoomNID
- roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs)
+ roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs)
if err != nil {
return nil, err
}
@@ -398,7 +471,7 @@ func (d *Database) Events(
}
fetchNIDList = append(fetchNIDList, n)
}
- dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList)
+ dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
if err != nil {
return nil, err
}
@@ -440,19 +513,19 @@ func (d *Database) MembershipUpdater(
return updater, err
}
-func (d *Database) GetLatestEventsForUpdate(
- ctx context.Context, roomInfo types.RoomInfo,
-) (*LatestEventsUpdater, error) {
- if d.GetLatestEventsForUpdateFn != nil {
- return d.GetLatestEventsForUpdateFn(ctx, roomInfo)
+func (d *Database) GetRoomUpdater(
+ ctx context.Context, roomInfo *types.RoomInfo,
+) (*RoomUpdater, error) {
+ if d.GetRoomUpdaterFn != nil {
+ return d.GetRoomUpdaterFn(ctx, roomInfo)
}
txn, err := d.DB.Begin()
if err != nil {
return nil, err
}
- var updater *LatestEventsUpdater
+ var updater *RoomUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
- updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo)
+ updater, err = NewRoomUpdater(ctx, d, txn, roomInfo)
return err
})
return updater, err
@@ -462,6 +535,13 @@ func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
+ return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
+}
+
+func (d *Database) storeEvent(
+ ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event,
+ authEventNIDs []types.EventNID, isRejected bool,
+) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
@@ -472,8 +552,11 @@ func (d *Database) StoreEvent(
redactedEventID string
err error
)
-
- err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ var txn *sql.Tx
+ if updater != nil {
+ txn = updater.txn
+ }
+ err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.
@@ -546,42 +629,32 @@ func (d *Database) StoreEvent(
// events updater because it somewhat works as a mutex, ensuring
// that there's a row-level lock on the latest room events (well,
// on Postgres at least).
- var roomInfo *types.RoomInfo
- var updater *LatestEventsUpdater
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
- roomInfo, err = d.RoomInfo(ctx, event.RoomID())
- if err != nil {
- return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
- }
- if roomInfo == nil && len(prevEvents) > 0 {
- return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
- }
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
// function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`.
- updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo)
- if err != nil {
- return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err)
- }
- // Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents
- // and EndTransaction in a writer then it's possible for a new write txn to be made between the two
- // function calls which will then fail with 'database is locked'. This new write txn would HAVE to be
- // something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to
- // SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases
- // as they don't go via InputRoomEvents
- err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error {
- if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
- return fmt.Errorf("updater.StorePreviousEvents: %w", err)
+ succeeded := false
+ if updater == nil {
+ var roomInfo *types.RoomInfo
+ roomInfo, err = d.RoomInfo(ctx, event.RoomID())
+ if err != nil {
+ return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
- succeeded := true
- err = sqlutil.EndTransaction(updater, &succeeded)
- return err
- })
- if err != nil {
- return 0, 0, types.StateAtEvent{}, nil, "", err
+ if roomInfo == nil && len(prevEvents) > 0 {
+ return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
+ }
+ updater, err = d.GetRoomUpdater(ctx, roomInfo)
+ if err != nil {
+ return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
+ }
+ defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
+ }
+ if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
+ return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
}
+ succeeded = true
}
return eventNID, roomNID, types.StateAtEvent{
@@ -603,7 +676,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
}
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
- return d.PublishedTable.SelectAllPublishedRooms(ctx, true)
+ return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
}
func (d *Database) assignRoomNID(
@@ -875,14 +948,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
eventNIDs = append(eventNIDs, e.EventNID)
}
}
- eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
+ eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
// return the event requested
for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
- data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID})
+ data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
if err != nil {
return nil, err
}
@@ -922,11 +995,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
}
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
}
- roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState)
+ roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
}
- roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs)
+ roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
}
@@ -945,7 +1018,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
// isn't a failure.
- eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes)
+ eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
}
@@ -965,7 +1038,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
- eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys)
+ eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
}
@@ -999,11 +1072,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
}
}
- eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
+ eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
- events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
+ events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
}
@@ -1027,11 +1100,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
- roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs)
+ roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
}
- userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs)
+ userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
if err != nil {
return nil, err
}
@@ -1041,7 +1114,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
- nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
+ nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}
@@ -1057,12 +1130,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
- return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID)
+ return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
}
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
- return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName)
+ return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName)
}
// GetKnownUsers searches all users that userID knows about.
@@ -1071,17 +1144,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
if err != nil {
return nil, err
}
- return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit)
+ return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
}
// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
- return d.RoomsTable.SelectRoomIDs(ctx)
+ return d.RoomsTable.SelectRoomIDs(ctx, nil)
}
// ForgetRoom sets a users room to forgotten
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
- roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID})
+ roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID})
if err != nil {
return err
}
diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go
index 53b21929..f470ea32 100644
--- a/roomserver/storage/sqlite3/event_json_table.go
+++ b/roomserver/storage/sqlite3/event_json_table.go
@@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON(
}
func (s *eventJSONStatements) BulkSelectEventJSON(
- ctx context.Context, eventNIDs []types.EventNID,
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) {
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
-
- rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
+ var rows *sql.Rows
+ var err error
+ if txn != nil {
+ rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...)
+ } else {
+ rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
+ }
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go
index 62fbce2d..bf12d5b8 100644
--- a/roomserver/storage/sqlite3/event_state_keys_table.go
+++ b/roomserver/storage/sqlite3/event_state_keys_table.go
@@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
- ctx context.Context, eventStateKeys []string,
+ ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
iEventStateKeys := make([]interface{}, len(eventStateKeys))
for k, v := range eventStateKeys {
iEventStateKeys[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
-
- rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
+ var rows *sql.Rows
+ var err error
+ if txn != nil {
+ rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
+ } else {
+ rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
+ }
if err != nil {
return nil, err
}
@@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
- ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
+ ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
-
- rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
+ selectPrep, err := s.db.Prepare(selectOrig)
+ if err != nil {
+ return nil, err
+ }
+ stmt := sqlutil.TxStmt(txn, selectPrep)
+ rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go
index 22df3fb2..f2c9c42f 100644
--- a/roomserver/storage/sqlite3/event_types_table.go
+++ b/roomserver/storage/sqlite3/event_types_table.go
@@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID(
}
func (s *eventTypeStatements) BulkSelectEventTypeNID(
- ctx context.Context, eventTypes []string,
+ ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
///////////////
iEventTypes := make([]interface{}, len(eventTypes))
@@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
if err != nil {
return nil, err
}
+ stmt := sqlutil.TxStmt(txn, selectPrep)
///////////////
- rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
+ rows, err := stmt.QueryContext(ctx, iEventTypes...)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go
index 7483e281..e1e6a597 100644
--- a/roomserver/storage/sqlite3/events_table.go
+++ b/roomserver/storage/sqlite3/events_table.go
@@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID(
- ctx context.Context, eventIDs []string,
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
@@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
if err != nil {
return nil, err
}
+ selectStmt = sqlutil.TxStmt(txn, selectStmt)
///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
@@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID(
- ctx context.Context, eventNIDs []types.EventNID,
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
@@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
if err != nil {
return nil, fmt.Errorf("s.db.Prepare: %w", err)
}
+ selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, params...)
if err != nil {
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
@@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID(
- ctx context.Context, eventIDs []string,
+ ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
@@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
if err != nil {
return nil, err
}
+ selectStmt = sqlutil.TxStmt(txn, selectStmt)
///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil {
@@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
if err != nil {
return nil, err
}
+ selectPrep = sqlutil.TxStmt(txn, selectPrep)
//////////////
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
@@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference(
}
// bulkSelectEventID returns a map from numeric event ID to string event ID.
-func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
+func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
///////////////
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
@@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
if err != nil {
return nil, err
}
+ selectStmt = sqlutil.TxStmt(txn, selectStmt)
///////////////
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
@@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
-func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
+func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
@@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str
if err != nil {
return nil, err
}
+ selectStmt = sqlutil.TxStmt(txn, selectStmt)
///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil {
@@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
}
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
- ctx context.Context, eventNIDs []types.EventNID,
+ ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) {
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil {
return nil, err
}
+ sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iEventNIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs {
iEventNIDs[i] = v
diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go
index c1d7347a..d54d313a 100644
--- a/roomserver/storage/sqlite3/invite_table.go
+++ b/roomserver/storage/sqlite3/invite_table.go
@@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
}
func (s *inviteStatements) InsertInviteEvent(
- ctx context.Context,
- txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
+ ctx context.Context, txn *sql.Tx,
+ inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
@@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent(
}
func (s *inviteStatements) UpdateInviteRetired(
- ctx context.Context,
- txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
// gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
@@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) {
- rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
+ stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
+ rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID,
)
if err != nil {
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index 2e58431d..181b4b4c 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
}
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
- err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
+ stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
+ err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten)
return
}
func (s *membershipStatements) SelectMembershipsFromRoom(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var selectStmt *sql.Stmt
@@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else {
selectStmt = s.selectMembershipsFromRoomStmt
}
+ selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil {
return nil, err
@@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
}
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt
@@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt
}
+ stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
@@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership(
}
func (s *membershipStatements) SelectRoomsWithMembership(
- ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
+ ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
- rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
+ stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
+ rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil {
return nil, err
}
@@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
-func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
+func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
- rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
+ var rows *sql.Rows
+ var err error
+ if txn != nil {
+ rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
+ } else {
+ rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
+ }
if err != nil {
return nil, err
}
@@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err()
}
-func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
- rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
+func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
+ rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
@@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
}
func (s *membershipStatements) UpdateForgetMembership(
- ctx context.Context,
- txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
+ ctx context.Context, txn *sql.Tx,
+ roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
@@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership(
return err
}
-func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
+func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) {
var nid types.RoomNID
- err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
+ stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
+ err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
@@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil
}
-func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
+func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
var nid types.RoomNID
- err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
+ stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
+ err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go
index b07c0ac4..9e416ace 100644
--- a/roomserver/storage/sqlite3/published_table.go
+++ b/roomserver/storage/sqlite3/published_table.go
@@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished(
}
func (s *publishedStatements) SelectPublishedFromRoomID(
- ctx context.Context, roomID string,
+ ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) {
- err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
+ stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
+ err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows {
return false, nil
}
@@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
}
func (s *publishedStatements) SelectAllPublishedRooms(
- ctx context.Context, published bool,
+ ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) {
- rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
+ stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
+ rows, err := stmt.QueryContext(ctx, published)
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go
index 323945b8..7c7bead9 100644
--- a/roomserver/storage/sqlite3/room_aliases_table.go
+++ b/roomserver/storage/sqlite3/room_aliases_table.go
@@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
}
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
- ctx context.Context, alias string,
+ ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) {
- err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
+ stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
+ err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows {
return "", nil
}
@@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
}
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
- ctx context.Context, roomID string,
+ ctx context.Context, txn *sql.Tx, roomID string,
) (aliases []string, err error) {
aliases = []string{}
- rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
+ stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
+ rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return
}
@@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
}
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
- ctx context.Context, alias string,
+ ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) {
- err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
+ stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
+ err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows {
return "", nil
}
diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go
index c441daec..5413475e 100644
--- a/roomserver/storage/sqlite3/rooms_table.go
+++ b/roomserver/storage/sqlite3/rooms_table.go
@@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}
-func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
- rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
+func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
+ stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
+ rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
@@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
return roomIDs, nil
}
-func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
+func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDsJSON string
- err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
+ stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
+ err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
)
if err != nil {
@@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs(
}
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
- ctx context.Context, roomNIDs []types.RoomNID,
+ ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil {
return nil, err
}
+ sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
@@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil
}
-func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
+func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
- rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
+ var rows *sql.Rows
+ var err error
+ if txn != nil {
+ rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...)
+ } else {
+ rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
+ }
if err != nil {
return nil, err
}
@@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil
}
-func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
+func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs {
iRoomIDs[i] = v
}
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
- rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
+ var rows *sql.Rows
+ var err error
+ if txn != nil {
+ rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...)
+ } else {
+ rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
+ }
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go
index 58b0b5dc..d51fc492 100644
--- a/roomserver/storage/sqlite3/state_block_table.go
+++ b/roomserver/storage/sqlite3/state_block_table.go
@@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
}
func (s *stateBlockStatements) BulkInsertStateData(
- ctx context.Context,
- txn *sql.Tx,
+ ctx context.Context, txn *sql.Tx,
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
@@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData(
if err != nil {
return 0, fmt.Errorf("json.Marshal: %w", err)
}
- err = s.insertStateDataStmt.QueryRowContext(
+ stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
+ err = stmt.QueryRowContext(
ctx, nids.Hash(), js,
).Scan(&id)
return
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
- ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
+ ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) {
intfs := make([]interface{}, len(stateBlockNIDs))
for i := range stateBlockNIDs {
@@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
if err != nil {
return nil, err
}
+ selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, intfs...)
if err != nil {
return nil, err
diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go
index 040d99ae..3c4bde3f 100644
--- a/roomserver/storage/sqlite3/state_snapshot_table.go
+++ b/roomserver/storage/sqlite3/state_snapshot_table.go
@@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState(
}
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
- ctx context.Context, stateNIDs []types.StateSnapshotNID,
+ ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs {
@@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
if err != nil {
return nil, err
}
+ selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index 1fcc7989..325c253b 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
return err
}
d.Database = shared.Database{
- DB: db,
- Cache: cache,
- Writer: sqlutil.NewExclusiveWriter(),
- EventsTable: events,
- EventTypesTable: eventTypes,
- EventStateKeysTable: eventStateKeys,
- EventJSONTable: eventJSON,
- RoomsTable: rooms,
- StateBlockTable: stateBlock,
- StateSnapshotTable: stateSnapshot,
- PrevEventsTable: prevEvents,
- RoomAliasesTable: roomAliases,
- InvitesTable: invites,
- MembershipTable: membership,
- PublishedTable: published,
- RedactionsTable: redactions,
- GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate,
+ DB: db,
+ Cache: cache,
+ Writer: sqlutil.NewExclusiveWriter(),
+ EventsTable: events,
+ EventTypesTable: eventTypes,
+ EventStateKeysTable: eventStateKeys,
+ EventJSONTable: eventJSON,
+ RoomsTable: rooms,
+ StateBlockTable: stateBlock,
+ StateSnapshotTable: stateSnapshot,
+ PrevEventsTable: prevEvents,
+ RoomAliasesTable: roomAliases,
+ InvitesTable: invites,
+ MembershipTable: membership,
+ PublishedTable: published,
+ RedactionsTable: redactions,
+ GetRoomUpdaterFn: d.GetRoomUpdater,
}
return nil
}
@@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
return false
}
-func (d *Database) GetLatestEventsForUpdate(
- ctx context.Context, roomInfo types.RoomInfo,
-) (*shared.LatestEventsUpdater, error) {
+func (d *Database) GetRoomUpdater(
+ ctx context.Context, roomInfo *types.RoomInfo,
+) (*shared.RoomUpdater, error) {
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
// multiple write transactions on sqlite. The code will perform additional
// write transactions independent of this one which will consistently cause
// 'database is locked' errors. As sqlite doesn't support multi-process on the
// same DB anyway, and we only execute updates sequentially, the only worries
// are for rolling back when things go wrong. (atomicity)
- return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo)
+ return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo)
}
func (d *Database) MembershipUpdater(
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index 6ad7ed2e..fed39b94 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -18,20 +18,20 @@ type EventJSONPair struct {
type EventJSON interface {
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error
- BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error)
+ BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error)
}
type EventTypes interface {
InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
- BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
+ BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error)
}
type EventStateKeys interface {
InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
- BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
- BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
+ BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
+ BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
}
type Events interface {
@@ -42,12 +42,12 @@ type Events interface {
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
- BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
- BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
+ BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error)
+ BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
- BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
+ BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error)
UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error)
UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error
@@ -55,12 +55,12 @@ type Events interface {
BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error)
BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error)
// BulkSelectEventID returns a map from numeric event ID to string event ID.
- BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
+ BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
- BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
+ BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
- SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
+ SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
}
type Rooms interface {
@@ -69,29 +69,29 @@ type Rooms interface {
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
- SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
- SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
- SelectRoomIDs(ctx context.Context) ([]string, error)
- BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
- BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error)
+ SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
+ SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
+ SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
+ BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
+ BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
}
type StateSnapshot interface {
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
- BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
+ BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
}
type StateBlock interface {
BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error)
- BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
+ BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
//BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
}
type RoomAliases interface {
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
- SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error)
- SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error)
- SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error)
+ SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error)
+ SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error)
+ SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error)
DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error)
}
@@ -106,7 +106,7 @@ type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
- SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
+ SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
}
type MembershipState int64
@@ -121,24 +121,24 @@ const (
type Membership interface {
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
- SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
- SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
- SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
+ SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
+ SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
+ SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
- SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
+ SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined.
- SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
- SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
+ SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
+ SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
- SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
- SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
+ SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
+ SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
}
type Published interface {
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error)
- SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error)
- SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error)
+ SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error)
+ SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error)
}
type RedactionInfo struct {