aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-05-08 19:25:44 +0000
committerGitHub <noreply@github.com>2023-05-08 19:25:44 +0000
commita49c9f01e227aeb12aa2f27d5bf1915453c23a3b (patch)
treee8de0a2dcb4f5f57094dd9024cc6b94793a080ae
parent2b34f88fde6b3aa633c8f23ff424a5db4951efd3 (diff)
Only require room version instead of room info for db.Events() (#3079)
This reduces the API requirements for the Events database to align with what is actually required.
-rw-r--r--cmd/resolve-state/main.go4
-rw-r--r--roomserver/internal/helpers/auth.go7
-rw-r--r--roomserver/internal/helpers/helpers.go12
-rw-r--r--roomserver/internal/input/input_events.go5
-rw-r--r--roomserver/internal/input/input_membership.go2
-rw-r--r--roomserver/internal/input/input_missing.go5
-rw-r--r--roomserver/internal/perform/perform_admin.go2
-rw-r--r--roomserver/internal/perform/perform_backfill.go9
-rw-r--r--roomserver/internal/perform/perform_invite.go5
-rw-r--r--roomserver/internal/query/query.go8
-rw-r--r--roomserver/roomserver_test.go2
-rw-r--r--roomserver/state/state.go13
-rw-r--r--roomserver/storage/interface.go4
-rw-r--r--roomserver/storage/shared/room_updater.go7
-rw-r--r--roomserver/storage/shared/storage.go22
-rw-r--r--roomserver/types/types.go3
16 files changed, 74 insertions, 36 deletions
diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go
index 1278b1cc..3a4255ba 100644
--- a/cmd/resolve-state/main.go
+++ b/cmd/resolve-state/main.go
@@ -91,7 +91,7 @@ func main() {
}
var eventEntries []types.Event
- eventEntries, err = roomserverDB.Events(ctx, roomInfo, eventNIDs)
+ eventEntries, err = roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
panic(err)
}
@@ -149,7 +149,7 @@ func main() {
}
fmt.Println("Fetching", len(eventNIDMap), "state events")
- eventEntries, err := roomserverDB.Events(ctx, roomInfo, eventNIDs)
+ eventEntries, err := roomserverDB.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
panic(err)
}
diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go
index 3d2beab3..24958091 100644
--- a/roomserver/internal/helpers/auth.go
+++ b/roomserver/internal/helpers/auth.go
@@ -219,7 +219,12 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID)
}
}
- if result.events, err = db.Events(ctx, roomInfo, eventNIDs); err != nil {
+
+ if roomInfo == nil {
+ err = types.ErrorInvalidRoomInfo
+ return
+ }
+ if result.events, err = db.Events(ctx, roomInfo.RoomVersion, eventNIDs); err != nil {
return
}
roomID := ""
diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go
index ea0074fc..95397cd5 100644
--- a/roomserver/internal/helpers/helpers.go
+++ b/roomserver/internal/helpers/helpers.go
@@ -86,7 +86,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, err
}
- events, err := db.Events(ctx, info, eventNIDs)
+ events, err := db.Events(ctx, info.RoomVersion, eventNIDs)
if err != nil {
return false, err
}
@@ -183,7 +183,10 @@ func GetMembershipsAtState(
util.Unique(eventNIDs)
// Get all of the events in this state
- stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
+ if roomInfo == nil {
+ return nil, types.ErrorInvalidRoomInfo
+ }
+ stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
return nil, err
}
@@ -235,7 +238,10 @@ func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types
func LoadEvents(
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.PDU, error) {
- stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
+ if roomInfo == nil {
+ return nil, types.ErrorInvalidRoomInfo
+ }
+ stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 9ae29c54..c8f5737f 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -805,7 +805,10 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
return err
}
- memberEvents, err := r.DB.Events(ctx, roomInfo, membershipNIDs)
+ if roomInfo == nil {
+ return types.ErrorInvalidRoomInfo
+ }
+ memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, membershipNIDs)
if err != nil {
return err
}
diff --git a/roomserver/internal/input/input_membership.go b/roomserver/internal/input/input_membership.go
index 947f6c15..98d7d13b 100644
--- a/roomserver/internal/input/input_membership.go
+++ b/roomserver/internal/input/input_membership.go
@@ -55,7 +55,7 @@ func (r *Inputer) updateMemberships(
// Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON?
- events, err := updater.Events(ctx, nil, eventNIDs)
+ events, err := updater.Events(ctx, "", eventNIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go
index 89ba0756..8a123522 100644
--- a/roomserver/internal/input/input_missing.go
+++ b/roomserver/internal/input/input_missing.go
@@ -398,7 +398,10 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
for _, entry := range stateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
- stateEvents, err := t.db.Events(ctx, t.roomInfo, stateEventNIDs)
+ if t.roomInfo == nil {
+ return nil
+ }
+ stateEvents, err := t.db.Events(ctx, t.roomInfo.RoomVersion, stateEventNIDs)
if err != nil {
t.log.WithError(err).Warnf("failed to load state events locally")
return nil
diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go
index 70668a20..375eefbe 100644
--- a/roomserver/internal/perform/perform_admin.go
+++ b/roomserver/internal/perform/perform_admin.go
@@ -60,7 +60,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
return nil, err
}
- memberEvents, err := r.DB.Events(ctx, roomInfo, memberNIDs)
+ memberEvents, err := r.DB.Events(ctx, roomInfo.RoomVersion, memberNIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go
index 8dbfad9b..fb579f03 100644
--- a/roomserver/internal/perform/perform_backfill.go
+++ b/roomserver/internal/perform/perform_backfill.go
@@ -533,7 +533,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
roomNID = nid.RoomNID
}
}
- eventsWithNids, err := b.db.Events(ctx, &b.roomInfo, eventNIDs)
+ eventsWithNids, err := b.db.Events(ctx, b.roomInfo.RoomVersion, eventNIDs)
if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err
@@ -563,7 +563,10 @@ func joinEventsFromHistoryVisibility(
}
// Get all of the events in this state
- stateEvents, err := db.Events(ctx, roomInfo, eventNIDs)
+ if roomInfo == nil {
+ return nil, gomatrixserverlib.HistoryVisibilityJoined, types.ErrorInvalidRoomInfo
+ }
+ stateEvents, err := db.Events(ctx, roomInfo.RoomVersion, eventNIDs)
if err != nil {
// even though the default should be shared, restricting the visibility to joined
// feels more secure here.
@@ -586,7 +589,7 @@ func joinEventsFromHistoryVisibility(
if err != nil {
return nil, visibility, err
}
- evs, err := db.Events(ctx, roomInfo, joinEventNIDs)
+ evs, err := db.Events(ctx, roomInfo.RoomVersion, joinEventNIDs)
return evs, visibility, err
}
diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index a920811d..db0b53fe 100644
--- a/roomserver/internal/perform/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -269,7 +269,10 @@ func buildInviteStrippedState(
for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID)
}
- stateEvents, err := db.Events(ctx, info, stateNIDs)
+ if info == nil {
+ return nil, types.ErrorInvalidRoomInfo
+ }
+ stateEvents, err := db.Events(ctx, info.RoomVersion, stateNIDs)
if err != nil {
return nil, err
}
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index c74bf21b..27c0dd0c 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -212,7 +212,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom
response.HasBeenInRoom = true
- evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID})
+ evs, err := r.DB.Events(ctx, info.RoomVersion, []types.EventNID{membershipEventNID})
if err != nil {
return err
}
@@ -344,7 +344,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
}
- events, err = r.DB.Events(ctx, info, eventNIDs)
+ events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
if err != nil {
return fmt.Errorf("r.DB.Events: %w", err)
}
@@ -383,7 +383,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return err
}
- events, err = r.DB.Events(ctx, info, eventNIDs)
+ events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
} else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil {
@@ -967,7 +967,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid
// membership event.
for _, joinNID := range joinNIDs {
- events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID})
+ events, err := r.DB.Events(ctx, roomInfo.RoomVersion, []types.EventNID{joinNID})
if err != nil || len(events) != 1 {
continue
}
diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go
index c0f3e12d..d19ebebe 100644
--- a/roomserver/roomserver_test.go
+++ b/roomserver/roomserver_test.go
@@ -571,7 +571,7 @@ func TestRedaction(t *testing.T) {
if ev.Type() == spec.MRoomRedaction {
nids, err := db.EventNIDs(ctx, []string{ev.Redacts()})
assert.NoError(t, err)
- evs, err := db.Events(ctx, roomInfo, []types.EventNID{nids[ev.Redacts()].EventNID})
+ evs, err := db.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[ev.Redacts()].EventNID})
assert.NoError(t, err)
assert.Equal(t, 1, len(evs))
assert.Equal(t, tc.wantRedacted, evs[0].Redacted())
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index b2a8a8d9..f38d8f96 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -41,7 +41,7 @@ type StateResolutionStorage interface {
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
- Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
+ Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
}
@@ -85,7 +85,10 @@ func (p *StateResolution) Resolve(ctx context.Context, eventID string) (*gomatri
return nil, fmt.Errorf("unable to find power level event")
}
- events, err := p.db.Events(ctx, p.roomInfo, []types.EventNID{plNID})
+ if p.roomInfo == nil {
+ return nil, types.ErrorInvalidRoomInfo
+ }
+ events, err := p.db.Events(ctx, p.roomInfo.RoomVersion, []types.EventNID{plNID})
if err != nil {
return nil, err
}
@@ -1134,7 +1137,11 @@ func (v *StateResolution) loadStateEvents(
eventNIDs = append(eventNIDs, entry.EventNID)
}
}
- events, err := v.db.Events(ctx, v.roomInfo, eventNIDs)
+
+ if v.roomInfo == nil {
+ return nil, nil, types.ErrorInvalidRoomInfo
+ }
+ events, err := v.db.Events(ctx, v.roomInfo.RoomVersion, eventNIDs)
if err != nil {
return nil, nil, err
}
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index 8da6b350..6bc4ce9a 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -72,7 +72,7 @@ type Database interface {
) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events.
- Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
+ Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
@@ -224,7 +224,7 @@ type EventDatabase interface {
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
- Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error)
+ Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
// MaybeRedactEvent returns the redaction event and the redacted event if this call resulted in a redaction, else an error
// (nil if there was nothing to do)
MaybeRedactEvent(
diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go
index dc1db082..5a20c67b 100644
--- a/roomserver/storage/shared/room_updater.go
+++ b/roomserver/storage/shared/room_updater.go
@@ -116,8 +116,11 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent
})
}
-func (u *RoomUpdater) Events(ctx context.Context, _ *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
- return u.d.events(ctx, u.txn, u.roomInfo, eventNIDs)
+func (u *RoomUpdater) Events(ctx context.Context, _ gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) {
+ if u.roomInfo == nil {
+ return nil, types.ErrorInvalidRoomInfo
+ }
+ return u.d.events(ctx, u.txn, u.roomInfo.RoomVersion, eventNIDs)
}
func (u *RoomUpdater) SnapshotNIDFromEventID(
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index aa8e7341..60e46c47 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -392,7 +392,10 @@ func (d *EventDatabase) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomInfo
nids = append(nids, nid.EventNID)
}
- return d.events(ctx, txn, roomInfo, nids)
+ if roomInfo == nil {
+ return nil, types.ErrorInvalidRoomInfo
+ }
+ return d.events(ctx, txn, roomInfo.RoomVersion, nids)
}
func (d *Database) LatestEventIDs(
@@ -531,17 +534,13 @@ func (d *Database) GetInvitesForUser(
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
}
-func (d *EventDatabase) Events(ctx context.Context, roomInfo *types.RoomInfo, eventNIDs []types.EventNID) ([]types.Event, error) {
- return d.events(ctx, nil, roomInfo, eventNIDs)
+func (d *EventDatabase) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) {
+ return d.events(ctx, nil, roomVersion, eventNIDs)
}
func (d *EventDatabase) events(
- ctx context.Context, txn *sql.Tx, roomInfo *types.RoomInfo, inputEventNIDs types.EventNIDs,
+ ctx context.Context, txn *sql.Tx, roomVersion gomatrixserverlib.RoomVersion, inputEventNIDs types.EventNIDs,
) ([]types.Event, error) {
- if roomInfo == nil { // this should never happen
- return nil, fmt.Errorf("unable to parse events without roomInfo")
- }
-
sort.Sort(inputEventNIDs)
events := make(map[types.EventNID]gomatrixserverlib.PDU, len(inputEventNIDs))
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
@@ -579,7 +578,7 @@ func (d *EventDatabase) events(
eventIDs = map[types.EventNID]string{}
}
- verImpl, err := gomatrixserverlib.GetRoomVersion(roomInfo.RoomVersion)
+ verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion)
if err != nil {
return nil, err
}
@@ -1107,7 +1106,10 @@ func (d *EventDatabase) loadEvent(ctx context.Context, roomInfo *types.RoomInfo,
if len(nids) == 0 {
return nil
}
- evs, err := d.Events(ctx, roomInfo, []types.EventNID{nids[eventID].EventNID})
+ if roomInfo == nil {
+ return nil
+ }
+ evs, err := d.Events(ctx, roomInfo.RoomVersion, []types.EventNID{nids[eventID].EventNID})
if err != nil {
return nil
}
diff --git a/roomserver/types/types.go b/roomserver/types/types.go
index 34934585..e986b9da 100644
--- a/roomserver/types/types.go
+++ b/roomserver/types/types.go
@@ -17,6 +17,7 @@ package types
import (
"encoding/json"
+ "fmt"
"sort"
"strings"
"sync"
@@ -328,3 +329,5 @@ func (r *RoomInfo) CopyFrom(r2 *RoomInfo) {
r.stateSnapshotNID = r2.stateSnapshotNID
r.isStub = r2.isStub
}
+
+var ErrorInvalidRoomInfo = fmt.Errorf("room info is invalid")