aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-02-24 09:40:20 +0100
committerGitHub <noreply@github.com>2023-02-24 09:40:20 +0100
commitad07b169b8a58b5a843b7b19ff0a989399d0aea0 (patch)
tree8904e5e52ecec90aa94f748c10a08b08cdf01de1 /roomserver/internal
parente6aa0955ff4113114ff8f30073582cc4ecc454fa (diff)
Refactor `StoreEvent` and create a new `RoomDatabase` interface (#2985)
This PR changes a few things: - It pulls out the creation of several NIDs from the `StoreEvent` function to make the functions more reusable - Uses more caching when using those NIDs to avoid DB round trips
Diffstat (limited to 'roomserver/internal')
-rw-r--r--roomserver/internal/alias.go20
-rw-r--r--roomserver/internal/api.go1
-rw-r--r--roomserver/internal/helpers/auth.go23
-rw-r--r--roomserver/internal/helpers/helpers.go22
-rw-r--r--roomserver/internal/helpers/helpers_test.go13
-rw-r--r--roomserver/internal/input/input.go2
-rw-r--r--roomserver/internal/input/input_events.go50
-rw-r--r--roomserver/internal/input/input_membership.go2
-rw-r--r--roomserver/internal/input/input_missing.go10
-rw-r--r--roomserver/internal/perform/perform_admin.go2
-rw-r--r--roomserver/internal/perform/perform_backfill.go54
-rw-r--r--roomserver/internal/perform/perform_inbound_peek.go6
-rw-r--r--roomserver/internal/perform/perform_invite.go4
-rw-r--r--roomserver/internal/perform/perform_unpeek.go5
-rw-r--r--roomserver/internal/query/query.go40
-rw-r--r--roomserver/internal/query/query_test.go2
16 files changed, 139 insertions, 117 deletions
diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go
index 329e6af7..fc61b7f4 100644
--- a/roomserver/internal/alias.go
+++ b/roomserver/internal/alias.go
@@ -30,26 +30,6 @@ import (
"github.com/tidwall/sjson"
)
-// RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API.
-type RoomserverInternalAPIDatabase interface {
- // Save a given room alias with the room ID it refers to.
- // Returns an error if there was a problem talking to the database.
- SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
- // Look up the room ID a given alias refers to.
- // Returns an error if there was a problem talking to the database.
- GetRoomIDForAlias(ctx context.Context, alias string) (string, error)
- // Look up all aliases referring to a given room ID.
- // Returns an error if there was a problem talking to the database.
- GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
- // Remove a given room alias.
- // Returns an error if there was a problem talking to the database.
- RemoveRoomAlias(ctx context.Context, alias string) error
- // Look up the room version for a given room.
- GetRoomVersionForRoom(
- ctx context.Context, roomID string,
- ) (gomatrixserverlib.RoomVersion, error)
-}
-
// SetRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) SetRoomAlias(
ctx context.Context,
diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go
index 451b3769..c43b9d04 100644
--- a/roomserver/internal/api.go
+++ b/roomserver/internal/api.go
@@ -155,7 +155,6 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
r.Unpeeker = &perform.Unpeeker{
ServerName: r.ServerName,
Cfg: r.Cfg,
- DB: r.DB,
FSAPI: r.fsAPI,
Inputer: r.Inputer,
}
diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go
index 03d8bca0..27c8dd8f 100644
--- a/roomserver/internal/helpers/auth.go
+++ b/roomserver/internal/helpers/auth.go
@@ -31,7 +31,8 @@ import (
// the soft-fail bool.
func CheckForSoftFail(
ctx context.Context,
- db storage.Database,
+ db storage.RoomDatabase,
+ roomInfo *types.RoomInfo,
event *gomatrixserverlib.HeaderedEvent,
stateEventIDs []string,
) (bool, error) {
@@ -45,16 +46,6 @@ func CheckForSoftFail(
return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
}
} else {
- // Work out if the room exists.
- var roomInfo *types.RoomInfo
- roomInfo, err = db.RoomInfo(ctx, event.RoomID())
- if err != nil {
- return false, fmt.Errorf("db.RoomNID: %w", err)
- }
- if roomInfo == nil || roomInfo.IsStub() {
- return false, nil
- }
-
// Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now.
roomState := state.NewStateResolution(db, roomInfo)
@@ -76,7 +67,7 @@ func CheckForSoftFail(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
- authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
+ authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries)
if err != nil {
return true, fmt.Errorf("loadAuthEvents: %w", err)
}
@@ -93,7 +84,8 @@ func CheckForSoftFail(
// Returns the numeric IDs for the auth events.
func CheckAuthEvents(
ctx context.Context,
- db storage.Database,
+ db storage.RoomDatabase,
+ roomNID types.RoomNID,
event *gomatrixserverlib.HeaderedEvent,
authEventIDs []string,
) ([]types.EventNID, error) {
@@ -108,7 +100,7 @@ func CheckAuthEvents(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
- authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
+ authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries)
if err != nil {
return nil, fmt.Errorf("loadAuthEvents: %w", err)
}
@@ -201,6 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
func loadAuthEvents(
ctx context.Context,
db state.StateResolutionStorage,
+ roomNID types.RoomNID,
needed gomatrixserverlib.StateNeeded,
state []types.StateEntry,
) (result authEvents, err error) {
@@ -223,7 +216,7 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID)
}
}
- if result.events, err = db.Events(ctx, eventNIDs); err != nil {
+ if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil {
return
}
roomID := ""
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index 7efad7af..ee1610cf 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, err
}
- events, err := db.Events(ctx, eventNIDs)
+ events, err := db.Events(ctx, info.RoomNID, eventNIDs)
if err != nil {
return false, err
}
@@ -157,7 +157,7 @@ func IsInvitePending(
// only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events.
func GetMembershipsAtState(
- ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
+ ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) {
var eventNIDs types.EventNIDs
@@ -177,7 +177,7 @@ func GetMembershipsAtState(
util.Unique(eventNIDs)
// Get all of the events in this state
- stateEvents, err := db.Events(ctx, eventNIDs)
+ stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
if err != nil {
return nil, err
}
@@ -220,16 +220,16 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
-func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
+func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
}
func LoadEvents(
- ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
+ ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) {
- stateEvents, err := db.Events(ctx, eventNIDs)
+ stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
if err != nil {
return nil, err
}
@@ -242,13 +242,13 @@ func LoadEvents(
}
func LoadStateEvents(
- ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
+ ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry,
) ([]*gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID
}
- return LoadEvents(ctx, db, eventNIDs)
+ return LoadEvents(ctx, db, roomNID, eventNIDs)
}
func CheckServerAllowedToSeeEvent(
@@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState(
return nil, nil
}
- return LoadStateEvents(ctx, db, filteredEntries)
+ return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries)
}
// TODO: Remove this when we have tests to assert correctness of this function
@@ -366,7 +366,7 @@ BFSLoop:
next = make([]string, 0)
}
// Retrieve the events to process from the database.
- events, err = db.EventsFromIDs(ctx, front)
+ events, err = db.EventsFromIDs(ctx, info.RoomNID, front)
if err != nil {
return resultNIDs, redactEventIDs, err
}
@@ -467,7 +467,7 @@ func QueryLatestEventsAndState(
return err
}
- stateEvents, err := LoadStateEvents(ctx, db, stateEntries)
+ stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries)
if err != nil {
return err
}
diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go
index aa5c30e4..62730df1 100644
--- a/roomserver/internal/helpers/helpers_test.go
+++ b/roomserver/internal/helpers/helpers_test.go
@@ -38,7 +38,18 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
var authNIDs []types.EventNID
for _, x := range room.Events() {
- evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false)
+ roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap())
+ assert.NoError(t, err)
+ assert.Greater(t, roomNID, types.RoomNID(0))
+
+ eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type())
+ assert.NoError(t, err)
+ assert.Greater(t, eventTypeNID, types.EventTypeNID(0))
+
+ eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey())
+ assert.NoError(t, err)
+
+ evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false)
assert.NoError(t, err)
authNIDs = append(authNIDs, evNID)
}
diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go
index 94131103..2ec19f01 100644
--- a/roomserver/internal/input/input.go
+++ b/roomserver/internal/input/input.go
@@ -76,7 +76,7 @@ type Inputer struct {
Cfg *config.RoomServer
Base *base.BaseDendrite
ProcessContext *process.ProcessContext
- DB storage.Database
+ DB storage.RoomDatabase
NATSClient *nats.Conn
JetStream nats.JetStreamContext
Durable nats.SubOpt
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 67edb321..fe35efb2 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -308,10 +308,10 @@ func (r *Inputer) processRoomEvent(
}
var softfail bool
- if input.Kind == api.KindNew {
+ if input.Kind == api.KindNew && !isCreateEvent {
// Check that the event passes authentication checks based on the
// current room state.
- softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
+ softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs)
if err != nil {
logger.WithError(err).Warn("Error authing soft-failed event")
}
@@ -322,8 +322,8 @@ func (r *Inputer) processRoomEvent(
// bother doing this if the event was already rejected as it just ends up
// burning CPU time.
historyVisibility := gomatrixserverlib.HistoryVisibilityShared // Default to shared.
- if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected {
- historyVisibility, rejectionErr, err = r.processStateBefore(ctx, input, missingPrev)
+ if input.Kind != api.KindOutlier && rejectionErr == nil && !isRejected && !isCreateEvent {
+ historyVisibility, rejectionErr, err = r.processStateBefore(ctx, roomInfo.RoomNID, input, missingPrev)
if err != nil {
return fmt.Errorf("r.processStateBefore: %w", err)
}
@@ -332,8 +332,23 @@ func (r *Inputer) processRoomEvent(
}
}
+ roomNID, err := r.DB.GetOrCreateRoomNID(ctx, event)
+ if err != nil {
+ return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
+ }
+
+ eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, event.Type())
+ if err != nil {
+ return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err)
+ }
+
+ eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey())
+ if err != nil {
+ return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err)
+ }
+
// Store the event.
- _, _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, authEventNIDs, isRejected)
+ _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err)
}
@@ -474,6 +489,7 @@ func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixse
// nolint:nakedret
func (r *Inputer) processStateBefore(
ctx context.Context,
+ roomNID types.RoomNID,
input *api.InputRoomEvent,
missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
@@ -489,7 +505,7 @@ func (r *Inputer) processStateBefore(
case input.HasState:
// If we're overriding the state then we need to go and retrieve
// them from the database. It's a hard error if they are missing.
- stateEvents, err := r.DB.EventsFromIDs(ctx, input.StateEventIDs)
+ stateEvents, err := r.DB.EventsFromIDs(ctx, roomNID, input.StateEventIDs)
if err != nil {
return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err)
}
@@ -567,6 +583,7 @@ func (r *Inputer) processStateBefore(
// we've failed to retrieve the auth chain altogether (in which case
// an error is returned) or we've successfully retrieved them all and
// they are now in the database.
+// nolint: gocyclo
func (r *Inputer) fetchAuthEvents(
ctx context.Context,
logger *logrus.Entry,
@@ -587,7 +604,7 @@ func (r *Inputer) fetchAuthEvents(
}
for _, authEventID := range authEventIDs {
- authEvents, err := r.DB.EventsFromIDs(ctx, []string{authEventID})
+ authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo.RoomNID, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
unknown[authEventID] = struct{}{}
continue
@@ -673,8 +690,23 @@ nextAuthEvent:
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
}
+ roomNID, err := r.DB.GetOrCreateRoomNID(ctx, authEvent)
+ if err != nil {
+ return fmt.Errorf("r.DB.GetOrCreateRoomNID: %w", err)
+ }
+
+ eventTypeNID, err := r.DB.GetOrCreateEventTypeNID(ctx, authEvent.Type())
+ if err != nil {
+ return fmt.Errorf("r.DB.GetOrCreateEventTypeNID: %w", err)
+ }
+
+ eventStateKeyNID, err := r.DB.GetOrCreateEventStateKeyNID(ctx, event.StateKey())
+ if err != nil {
+ return fmt.Errorf("r.DB.GetOrCreateEventStateKeyNID: %w", err)
+ }
+
// Finally, store the event in the database.
- eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, authEventNIDs, isRejected)
+ eventNID, _, _, _, err := r.DB.StoreEvent(ctx, authEvent, roomNID, eventTypeNID, eventStateKeyNID, authEventNIDs, isRejected)
if err != nil {
return fmt.Errorf("updater.StoreEvent: %w", err)
}
@@ -750,7 +782,7 @@ func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event
return err
}
- memberEvents, err := r.DB.Events(ctx, membershipNIDs)
+ memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, membershipNIDs)
if err != nil {
return err
}
diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go
index 28a54623..99a01255 100644
--- a/roomserver/internal/input/input_membership.go
+++ b/roomserver/internal/input/input_membership.go
@@ -53,7 +53,7 @@ func (r *Inputer) updateMemberships(
// Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON?
- events, err := updater.Events(ctx, eventNIDs)
+ events, err := updater.Events(ctx, 0, eventNIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go
index 03ac2b38..c8b7d31d 100644
--- a/roomserver/internal/input/input_missing.go
+++ b/roomserver/internal/input/input_missing.go
@@ -43,7 +43,7 @@ type missingStateReq struct {
log *logrus.Entry
virtualHost gomatrixserverlib.ServerName
origin gomatrixserverlib.ServerName
- db storage.Database
+ db storage.RoomDatabase
roomInfo *types.RoomInfo
inputer *Inputer
keys gomatrixserverlib.JSONVerifier
@@ -395,7 +395,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
for _, entry := range stateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
- stateEvents, err := t.db.Events(ctx, stateEventNIDs)
+ stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomNID, stateEventNIDs)
if err != nil {
t.log.WithError(err).Warnf("failed to load state events locally")
return nil
@@ -432,7 +432,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
missingEventList = append(missingEventList, evID)
}
t.log.WithField("count", len(missingEventList)).Debugf("Fetching missing auth events")
- events, err := t.db.EventsFromIDs(ctx, missingEventList)
+ events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList)
if err != nil {
return nil
}
@@ -702,7 +702,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
}
t.haveEventsMutex.Unlock()
- events, err := t.db.EventsFromIDs(ctx, missingEventList)
+ events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, missingEventList)
if err != nil {
return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err)
}
@@ -844,7 +844,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
if localFirst {
// fetch from the roomserver
- events, err := t.db.EventsFromIDs(ctx, []string{missingEventID})
+ events, err := t.db.EventsFromIDs(ctx, t.roomInfo.RoomNID, []string{missingEventID})
if err != nil {
t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
} else if len(events) == 1 {
diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go
index 3256162b..2efe2255 100644
--- a/roomserver/internal/perform/perform_admin.go
+++ b/roomserver/internal/perform/perform_admin.go
@@ -70,7 +70,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil
}
- memberEvents, err := r.DB.Events(ctx, memberNIDs)
+ memberEvents, err := r.DB.Events(ctx, roomInfo.RoomNID, memberNIDs)
if err != nil {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go
index d9214fdc..3a3a049d 100644
--- a/roomserver/internal/perform/perform_backfill.go
+++ b/roomserver/internal/perform/perform_backfill.go
@@ -86,7 +86,7 @@ func (r *Backfiller) PerformBackfill(
// Retrieve events from the list that was filled previously. If we fail to get
// events from the database then attempt once to get them from federation instead.
var loadedEvents []*gomatrixserverlib.Event
- loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs)
+ loadedEvents, err = helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
if err != nil {
if _, ok := err.(types.MissingEventError); ok {
return r.backfillViaFederation(ctx, request, response)
@@ -258,6 +258,7 @@ type backfillRequester struct {
eventIDToBeforeStateIDs map[string][]string
eventIDMap map[string]*gomatrixserverlib.Event
historyVisiblity gomatrixserverlib.HistoryVisibility
+ roomInfo types.RoomInfo
}
func newBackfillRequester(
@@ -454,14 +455,14 @@ FindSuccessor:
return nil
}
- stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID])
+ stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID)
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
return nil
}
// possibly return all joined servers depending on history visiblity
- memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost)
+ memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost)
b.historyVisiblity = visibility
if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
@@ -472,7 +473,7 @@ FindSuccessor:
// Retrieve all "m.room.member" state events of "join" membership, which
// contains the list of users in the room before the event, therefore all
// the servers in it at that moment.
- memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, stateEntries, true)
+ memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, info.RoomNID, stateEntries, true)
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
return nil
@@ -523,11 +524,15 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
}
eventNIDs := make([]types.EventNID, len(nidMap))
i := 0
+ roomNID := b.roomInfo.RoomNID
for _, nid := range nidMap {
- eventNIDs[i] = nid
+ eventNIDs[i] = nid.EventNID
i++
+ if roomNID == 0 {
+ roomNID = nid.RoomNID
+ }
}
- eventsWithNids, err := b.db.Events(ctx, eventNIDs)
+ eventsWithNids, err := b.db.Events(ctx, roomNID, eventNIDs)
if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err
@@ -544,7 +549,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
// pull all events and then filter by that table.
func joinEventsFromHistoryVisibility(
- ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry,
+ ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
thisServer gomatrixserverlib.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
var eventNIDs []types.EventNID
@@ -557,7 +562,7 @@ func joinEventsFromHistoryVisibility(
}
// Get all of the events in this state
- stateEvents, err := db.Events(ctx, eventNIDs)
+ stateEvents, err := db.Events(ctx, roomInfo.RoomNID, eventNIDs)
if err != nil {
// even though the default should be shared, restricting the visibility to joined
// feels more secure here.
@@ -570,21 +575,17 @@ func joinEventsFromHistoryVisibility(
// Can we see events in the room?
canSeeEvents := auth.IsServerAllowed(thisServer, true, events)
- visibility := gomatrixserverlib.HistoryVisibility(auth.HistoryVisibilityForRoom(events))
+ visibility := auth.HistoryVisibilityForRoom(events)
if !canSeeEvents {
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
return nil, visibility, nil
}
// get joined members
- info, err := db.RoomInfo(ctx, roomID)
- if err != nil {
- return nil, visibility, nil
- }
- joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
+ joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, false)
if err != nil {
return nil, visibility, err
}
- evs, err := db.Events(ctx, joinEventNIDs)
+ evs, err := db.Events(ctx, roomInfo.RoomNID, joinEventNIDs)
return evs, visibility, err
}
@@ -601,12 +602,31 @@ func persistEvents(ctx context.Context, db storage.Database, events []*gomatrixs
authNids := make([]types.EventNID, len(nidMap))
i := 0
for _, nid := range nidMap {
- authNids[i] = nid
+ authNids[i] = nid.EventNID
i++
}
+
+ roomNID, err = db.GetOrCreateRoomNID(ctx, ev.Unwrap())
+ if err != nil {
+ logrus.WithError(err).Error("failed to get or create roomNID")
+ continue
+ }
+
+ eventTypeNID, err := db.GetOrCreateEventTypeNID(ctx, ev.Type())
+ if err != nil {
+ logrus.WithError(err).Error("failed to get or create eventType NID")
+ continue
+ }
+
+ eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(ctx, ev.StateKey())
+ if err != nil {
+ logrus.WithError(err).Error("failed to get or create eventStateKey NID")
+ continue
+ }
+
var redactedEventID string
var redactionEvent *gomatrixserverlib.Event
- eventNID, roomNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), authNids, false)
+ eventNID, _, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), roomNID, eventTypeNID, eventStateKeyNID, authNids, false)
if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue
diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go
index 29decd36..9ac9edc4 100644
--- a/roomserver/internal/perform/perform_inbound_peek.go
+++ b/roomserver/internal/perform/perform_inbound_peek.go
@@ -29,7 +29,7 @@ import (
)
type InboundPeeker struct {
- DB storage.Database
+ DB storage.RoomDatabase
Inputer *input.Inputer
}
@@ -64,7 +64,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil {
return err
}
- latestEvents, err := r.DB.EventsFromIDs(ctx, []string{latestEventRefs[0].EventID})
+ latestEvents, err := r.DB.EventsFromIDs(ctx, info.RoomNID, []string{latestEventRefs[0].EventID})
if err != nil {
return err
}
@@ -88,7 +88,7 @@ func (r *InboundPeeker) PerformInboundPeek(
if err != nil {
return err
}
- stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, stateEntries)
+ stateEvents, err = helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
if err != nil {
return err
}
diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index f60247cd..118e1b87 100644
--- a/roomserver/internal/perform/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -194,7 +194,7 @@ func (r *Inviter) PerformInvite(
// try and see if the user is allowed to make this invite. We can't do
// this for invites coming in over federation - we have to take those on
// trust.
- _, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs())
+ _, err = helpers.CheckAuthEvents(ctx, r.DB, info.RoomNID, event, event.AuthEventIDs())
if err != nil {
logger.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
"processInviteEvent.checkAuthEvents failed for event",
@@ -291,7 +291,7 @@ func buildInviteStrippedState(
for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID)
}
- stateEvents, err := db.Events(ctx, stateNIDs)
+ stateEvents, err := db.Events(ctx, info.RoomNID, stateNIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/perform/perform_unpeek.go b/roomserver/internal/perform/perform_unpeek.go
index 0d97da4d..4d714be6 100644
--- a/roomserver/internal/perform/perform_unpeek.go
+++ b/roomserver/internal/perform/perform_unpeek.go
@@ -22,7 +22,6 @@ import (
fsAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/input"
- "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -31,9 +30,7 @@ type Unpeeker struct {
ServerName gomatrixserverlib.ServerName
Cfg *config.RoomServer
FSAPI fsAPI.RoomserverFederationAPI
- DB storage.Database
-
- Inputer *input.Inputer
+ Inputer *input.Inputer
}
// PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationapi.
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 1083bb23..ac34e0ff 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -102,7 +102,7 @@ func (r *Queryer) QueryStateAfterEvents(
return err
}
- stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
+ stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
if err != nil {
return err
}
@@ -138,17 +138,7 @@ func (r *Queryer) QueryEventsByID(
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
- eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs)
- if err != nil {
- return err
- }
-
- var eventNIDs []types.EventNID
- for _, nid := range eventNIDMap {
- eventNIDs = append(eventNIDs, nid)
- }
-
- events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs)
+ events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs)
if err != nil {
return err
}
@@ -196,7 +186,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom
response.HasBeenInRoom = true
- evs, err := r.DB.Events(ctx, []types.EventNID{membershipEventNID})
+ evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID})
if err != nil {
return err
}
@@ -278,10 +268,10 @@ func (r *Queryer) QueryMembershipAtEvent(
// once. If we have more than one membership event, we need to get the state for each state entry.
if canShortCircuit {
if len(memberships) == 0 {
- memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false)
+ memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
}
} else {
- memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntry, false)
+ memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
}
if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err)
@@ -328,7 +318,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
}
- events, err = r.DB.Events(ctx, eventNIDs)
+ events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
if err != nil {
return fmt.Errorf("r.DB.Events: %w", err)
}
@@ -367,14 +357,14 @@ func (r *Queryer) QueryMembershipsForRoom(
return err
}
- events, err = r.DB.Events(ctx, eventNIDs)
+ events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
} else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
}
- events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
+ events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly)
}
if err != nil {
@@ -425,7 +415,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse,
) (err error) {
- events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID})
+ events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID})
if err != nil {
return
}
@@ -476,7 +466,7 @@ func (r *Queryer) QueryMissingEvents(
eventsToFilter[id] = true
}
}
- events, err := r.DB.EventsFromIDs(ctx, front)
+ events, err := r.DB.EventsFromIDs(ctx, 0, front)
if err != nil {
return err
}
@@ -496,7 +486,7 @@ func (r *Queryer) QueryMissingEvents(
return err
}
- loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs)
+ loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
if err != nil {
return err
}
@@ -621,11 +611,11 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
return nil, rejected, false, err
}
- events, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
+ events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries)
return events, rejected, false, err
}
-type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
+type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error)
// GetAuthChain fetches the auth chain for the given auth events. An auth chain
// is the list of all events that are referenced in the auth_events section, and
@@ -643,7 +633,7 @@ func GetAuthChain(
for len(eventsToFetch) > 0 {
// Try to retrieve the events from the database.
- events, err := fn(ctx, eventsToFetch)
+ events, err := fn(ctx, 0, eventsToFetch)
if err != nil {
return nil, err
}
@@ -981,7 +971,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid
// membership event.
for _, joinNID := range joinNIDs {
- events, err := r.DB.Events(ctx, []types.EventNID{joinNID})
+ events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID})
if err != nil || len(events) != 1 {
continue
}
diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go
index 03627ea9..16761157 100644
--- a/roomserver/internal/query/query_test.go
+++ b/roomserver/internal/query/query_test.go
@@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
}
// EventsFromIDs implements RoomserverInternalAPIEventDB
-func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) {
+func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) {
for _, evID := range eventIDs {
res = append(res, types.Event{
EventNID: 0,