aboutsummaryrefslogtreecommitdiff
path: root/roomserver/state
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/state')
-rw-r--r--roomserver/state/state.go14
1 files changed, 8 insertions, 6 deletions
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index b9c5bbc4..1e776ff6 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -29,6 +29,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
)
@@ -44,20 +45,21 @@ type StateResolutionStorage interface {
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
- GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
}
type StateResolution struct {
db StateResolutionStorage
roomInfo *types.RoomInfo
events map[types.EventNID]gomatrixserverlib.PDU
+ Querier api.QuerySenderIDAPI
}
-func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
+func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo, querier api.QuerySenderIDAPI) StateResolution {
return StateResolution{
db: db,
roomInfo: roomInfo,
events: make(map[types.EventNID]gomatrixserverlib.PDU),
+ Querier: querier,
}
}
@@ -947,8 +949,8 @@ func (v *StateResolution) resolveConflictsV1(
}
// Resolve the conflicts.
- resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return v.db.GetUserIDForSender(ctx, roomID, senderID)
+ resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
})
// Map from the full events back to numeric state entries.
@@ -1061,8 +1063,8 @@ func (v *StateResolution) resolveConflictsV2(
conflictedEvents,
nonConflictedEvents,
authEvents,
- func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return v.db.GetUserIDForSender(ctx, roomID, senderID)
+ func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
},
)
}()