diff options
Diffstat (limited to 'roomserver/internal/query/query.go')
-rw-r--r-- | roomserver/internal/query/query.go | 50 |
1 files changed, 24 insertions, 26 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 4bd648a9..c74bf21b 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -121,17 +121,16 @@ func (r *Queryer) QueryStateAfterEvents( return fmt.Errorf("getAuthChain: %w", err) } - stateEventsPDU, err := gomatrixserverlib.ResolveConflicts( + stateEvents, err = gomatrixserverlib.ResolveConflicts( info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), ) if err != nil { return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) } - stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU) } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event}) + response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event}) } return nil @@ -176,7 +175,7 @@ func (r *Queryer) QueryEventsByID( } for _, event := range events { - response.Events = append(response.Events, &types.HeaderedEvent{Event: event.Event}) + response.Events = append(response.Events, &types.HeaderedEvent{PDU: event.PDU}) } return nil @@ -310,7 +309,7 @@ func (r *Queryer) QueryMembershipAtEvent( for i := range memberships { ev := memberships[i] if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) { - response.Membership[eventID] = &types.HeaderedEvent{Event: ev.Event} + response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU} } } } @@ -350,7 +349,7 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll) + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -399,7 +398,7 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll) + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -527,7 +526,7 @@ func (r *Queryer) QueryMissingEvents( if _, ok := redactEventIDs[event.EventID()]; ok { event.Redact() } - response.Events = append(response.Events, &types.HeaderedEvent{Event: event}) + response.Events = append(response.Events, &types.HeaderedEvent{PDU: event}) } } @@ -554,18 +553,18 @@ func (r *Queryer) QueryStateAndAuthChain( // the entire current state of the room // TODO: this probably means it should be a different query operation... if request.OnlyFetchAuthChain { - var authEvents []*gomatrixserverlib.Event + var authEvents []gomatrixserverlib.PDU authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs) if err != nil { return err } for _, event := range authEvents { - response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{Event: event}) + response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{PDU: event}) } return nil } - var stateEvents []*gomatrixserverlib.Event + var stateEvents []gomatrixserverlib.PDU stateEvents, rejected, stateMissing, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs) if err != nil { return err @@ -588,28 +587,27 @@ func (r *Queryer) QueryStateAndAuthChain( } if request.ResolveState { - stateEventsPDU, err2 := gomatrixserverlib.ResolveConflicts( + stateEvents, err = gomatrixserverlib.ResolveConflicts( info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), ) - if err2 != nil { - return err2 + if err != nil { + return err } - stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU) } for _, event := range stateEvents { - response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event}) + response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event}) } for _, event := range authEvents { - response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{Event: event}) + response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{PDU: event}) } return err } // first bool: is rejected, second bool: state missing -func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, bool, error) { +func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) { roomState := state.NewStateResolution(r.DB, roomInfo) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { @@ -651,13 +649,13 @@ type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Eve // given events. Will *not* error if we don't have all auth events. func GetAuthChain( ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string, -) ([]*gomatrixserverlib.Event, error) { +) ([]gomatrixserverlib.PDU, error) { // List of event IDs to fetch. On each pass, these events will be requested // from the database and the `eventsToFetch` will be updated with any new // events that we have learned about and need to find. When `eventsToFetch` // is eventually empty, we should have reached the end of the chain. eventsToFetch := authEventIDs - authEventsMap := make(map[string]*gomatrixserverlib.Event) + authEventsMap := make(map[string]gomatrixserverlib.PDU) for len(eventsToFetch) > 0 { // Try to retrieve the events from the database. @@ -673,14 +671,14 @@ func GetAuthChain( for _, event := range events { // Store the event in the event map - this prevents us from requesting it // from the database again. - authEventsMap[event.EventID()] = event.Event + authEventsMap[event.EventID()] = event.PDU // Extract all of the auth events from the newly obtained event. If we // don't already have a record of the event, record it in the list of // events we want to request for the next pass. - for _, authEvent := range event.AuthEvents() { - if _, ok := authEventsMap[authEvent.EventID]; !ok { - eventsToFetch = append(eventsToFetch, authEvent.EventID) + for _, authEventID := range event.AuthEventIDs() { + if _, ok := authEventsMap[authEventID]; !ok { + eventsToFetch = append(eventsToFetch, authEventID) } } } @@ -688,7 +686,7 @@ func GetAuthChain( // We've now retrieved all of the events we can. Flatten them down into an // array and return them. - var authEvents []*gomatrixserverlib.Event + var authEvents []gomatrixserverlib.PDU for _, event := range authEventsMap { authEvents = append(authEvents, event) } @@ -854,7 +852,7 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq } hchain := make([]*types.HeaderedEvent, len(chain)) for i := range chain { - hchain[i] = &types.HeaderedEvent{Event: chain[i]} + hchain[i] = &types.HeaderedEvent{PDU: chain[i]} } res.AuthChain = hchain return nil |