aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/query/query.go')
-rw-r--r--roomserver/internal/query/query.go50
1 files changed, 24 insertions, 26 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 4bd648a9..c74bf21b 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -121,17 +121,16 @@ func (r *Queryer) QueryStateAfterEvents(
return fmt.Errorf("getAuthChain: %w", err)
}
- stateEventsPDU, err := gomatrixserverlib.ResolveConflicts(
+ stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
)
if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
}
- stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU)
}
for _, event := range stateEvents {
- response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event})
+ response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event})
}
return nil
@@ -176,7 +175,7 @@ func (r *Queryer) QueryEventsByID(
}
for _, event := range events {
- response.Events = append(response.Events, &types.HeaderedEvent{Event: event.Event})
+ response.Events = append(response.Events, &types.HeaderedEvent{PDU: event.PDU})
}
return nil
@@ -310,7 +309,7 @@ func (r *Queryer) QueryMembershipAtEvent(
for i := range memberships {
ev := memberships[i]
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) {
- response.Membership[eventID] = &types.HeaderedEvent{Event: ev.Event}
+ response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
}
}
}
@@ -350,7 +349,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
- clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll)
+ clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
@@ -399,7 +398,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
- clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll)
+ clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
@@ -527,7 +526,7 @@ func (r *Queryer) QueryMissingEvents(
if _, ok := redactEventIDs[event.EventID()]; ok {
event.Redact()
}
- response.Events = append(response.Events, &types.HeaderedEvent{Event: event})
+ response.Events = append(response.Events, &types.HeaderedEvent{PDU: event})
}
}
@@ -554,18 +553,18 @@ func (r *Queryer) QueryStateAndAuthChain(
// the entire current state of the room
// TODO: this probably means it should be a different query operation...
if request.OnlyFetchAuthChain {
- var authEvents []*gomatrixserverlib.Event
+ var authEvents []gomatrixserverlib.PDU
authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs)
if err != nil {
return err
}
for _, event := range authEvents {
- response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{Event: event})
+ response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{PDU: event})
}
return nil
}
- var stateEvents []*gomatrixserverlib.Event
+ var stateEvents []gomatrixserverlib.PDU
stateEvents, rejected, stateMissing, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs)
if err != nil {
return err
@@ -588,28 +587,27 @@ func (r *Queryer) QueryStateAndAuthChain(
}
if request.ResolveState {
- stateEventsPDU, err2 := gomatrixserverlib.ResolveConflicts(
+ stateEvents, err = gomatrixserverlib.ResolveConflicts(
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
)
- if err2 != nil {
- return err2
+ if err != nil {
+ return err
}
- stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU)
}
for _, event := range stateEvents {
- response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event})
+ response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event})
}
for _, event := range authEvents {
- response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{Event: event})
+ response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{PDU: event})
}
return err
}
// first bool: is rejected, second bool: state missing
-func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, bool, error) {
+func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) {
roomState := state.NewStateResolution(r.DB, roomInfo)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil {
@@ -651,13 +649,13 @@ type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Eve
// given events. Will *not* error if we don't have all auth events.
func GetAuthChain(
ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string,
-) ([]*gomatrixserverlib.Event, error) {
+) ([]gomatrixserverlib.PDU, error) {
// List of event IDs to fetch. On each pass, these events will be requested
// from the database and the `eventsToFetch` will be updated with any new
// events that we have learned about and need to find. When `eventsToFetch`
// is eventually empty, we should have reached the end of the chain.
eventsToFetch := authEventIDs
- authEventsMap := make(map[string]*gomatrixserverlib.Event)
+ authEventsMap := make(map[string]gomatrixserverlib.PDU)
for len(eventsToFetch) > 0 {
// Try to retrieve the events from the database.
@@ -673,14 +671,14 @@ func GetAuthChain(
for _, event := range events {
// Store the event in the event map - this prevents us from requesting it
// from the database again.
- authEventsMap[event.EventID()] = event.Event
+ authEventsMap[event.EventID()] = event.PDU
// Extract all of the auth events from the newly obtained event. If we
// don't already have a record of the event, record it in the list of
// events we want to request for the next pass.
- for _, authEvent := range event.AuthEvents() {
- if _, ok := authEventsMap[authEvent.EventID]; !ok {
- eventsToFetch = append(eventsToFetch, authEvent.EventID)
+ for _, authEventID := range event.AuthEventIDs() {
+ if _, ok := authEventsMap[authEventID]; !ok {
+ eventsToFetch = append(eventsToFetch, authEventID)
}
}
}
@@ -688,7 +686,7 @@ func GetAuthChain(
// We've now retrieved all of the events we can. Flatten them down into an
// array and return them.
- var authEvents []*gomatrixserverlib.Event
+ var authEvents []gomatrixserverlib.PDU
for _, event := range authEventsMap {
authEvents = append(authEvents, event)
}
@@ -854,7 +852,7 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
}
hchain := make([]*types.HeaderedEvent, len(chain))
for i := range chain {
- hchain[i] = &types.HeaderedEvent{Event: chain[i]}
+ hchain[i] = &types.HeaderedEvent{PDU: chain[i]}
}
res.AuthChain = hchain
return nil