From f5b3144dc33ddcb2ab161323d422cab257d04b4c Mon Sep 17 00:00:00 2001
From: kegsay <kegan@matrix.org>
Date: Tue, 2 May 2023 15:03:16 +0100
Subject: Use PDU not *Event in HeaderedEvent (#3073)

Requires https://github.com/matrix-org/gomatrixserverlib/pull/376

This has numerous upsides:
 - Less type casting to `*Event` is required.
- Making Dendrite work with `PDU` interfaces means we can swap out Event
impls more easily.
 - Tests which represent weird event shapes are easier to write.

Part of a series of refactors on GMSL.
---
 roomserver/internal/query/query.go | 50 ++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 26 deletions(-)

(limited to 'roomserver/internal/query/query.go')

diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 4bd648a9..c74bf21b 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -121,17 +121,16 @@ func (r *Queryer) QueryStateAfterEvents(
 			return fmt.Errorf("getAuthChain: %w", err)
 		}
 
-		stateEventsPDU, err := gomatrixserverlib.ResolveConflicts(
+		stateEvents, err = gomatrixserverlib.ResolveConflicts(
 			info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
 		)
 		if err != nil {
 			return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
 		}
-		stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU)
 	}
 
 	for _, event := range stateEvents {
-		response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event})
+		response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event})
 	}
 
 	return nil
@@ -176,7 +175,7 @@ func (r *Queryer) QueryEventsByID(
 	}
 
 	for _, event := range events {
-		response.Events = append(response.Events, &types.HeaderedEvent{Event: event.Event})
+		response.Events = append(response.Events, &types.HeaderedEvent{PDU: event.PDU})
 	}
 
 	return nil
@@ -310,7 +309,7 @@ func (r *Queryer) QueryMembershipAtEvent(
 		for i := range memberships {
 			ev := memberships[i]
 			if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) {
-				response.Membership[eventID] = &types.HeaderedEvent{Event: ev.Event}
+				response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
 			}
 		}
 	}
@@ -350,7 +349,7 @@ func (r *Queryer) QueryMembershipsForRoom(
 			return fmt.Errorf("r.DB.Events: %w", err)
 		}
 		for _, event := range events {
-			clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll)
+			clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
 			response.JoinEvents = append(response.JoinEvents, clientEvent)
 		}
 		return nil
@@ -399,7 +398,7 @@ func (r *Queryer) QueryMembershipsForRoom(
 	}
 
 	for _, event := range events {
-		clientEvent := synctypes.ToClientEvent(event.Event, synctypes.FormatAll)
+		clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
 		response.JoinEvents = append(response.JoinEvents, clientEvent)
 	}
 
@@ -527,7 +526,7 @@ func (r *Queryer) QueryMissingEvents(
 			if _, ok := redactEventIDs[event.EventID()]; ok {
 				event.Redact()
 			}
-			response.Events = append(response.Events, &types.HeaderedEvent{Event: event})
+			response.Events = append(response.Events, &types.HeaderedEvent{PDU: event})
 		}
 	}
 
@@ -554,18 +553,18 @@ func (r *Queryer) QueryStateAndAuthChain(
 	// the entire current state of the room
 	// TODO: this probably means it should be a different query operation...
 	if request.OnlyFetchAuthChain {
-		var authEvents []*gomatrixserverlib.Event
+		var authEvents []gomatrixserverlib.PDU
 		authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs)
 		if err != nil {
 			return err
 		}
 		for _, event := range authEvents {
-			response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{Event: event})
+			response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{PDU: event})
 		}
 		return nil
 	}
 
-	var stateEvents []*gomatrixserverlib.Event
+	var stateEvents []gomatrixserverlib.PDU
 	stateEvents, rejected, stateMissing, err := r.loadStateAtEventIDs(ctx, info, request.PrevEventIDs)
 	if err != nil {
 		return err
@@ -588,28 +587,27 @@ func (r *Queryer) QueryStateAndAuthChain(
 	}
 
 	if request.ResolveState {
-		stateEventsPDU, err2 := gomatrixserverlib.ResolveConflicts(
+		stateEvents, err = gomatrixserverlib.ResolveConflicts(
 			info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
 		)
-		if err2 != nil {
-			return err2
+		if err != nil {
+			return err
 		}
-		stateEvents = gomatrixserverlib.TempCastToEvents(stateEventsPDU)
 	}
 
 	for _, event := range stateEvents {
-		response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{Event: event})
+		response.StateEvents = append(response.StateEvents, &types.HeaderedEvent{PDU: event})
 	}
 
 	for _, event := range authEvents {
-		response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{Event: event})
+		response.AuthChainEvents = append(response.AuthChainEvents, &types.HeaderedEvent{PDU: event})
 	}
 
 	return err
 }
 
 // first bool: is rejected, second bool: state missing
-func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, bool, bool, error) {
+func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) {
 	roomState := state.NewStateResolution(r.DB, roomInfo)
 	prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
 	if err != nil {
@@ -651,13 +649,13 @@ type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Eve
 // given events. Will *not* error if we don't have all auth events.
 func GetAuthChain(
 	ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string,
-) ([]*gomatrixserverlib.Event, error) {
+) ([]gomatrixserverlib.PDU, error) {
 	// List of event IDs to fetch. On each pass, these events will be requested
 	// from the database and the `eventsToFetch` will be updated with any new
 	// events that we have learned about and need to find. When `eventsToFetch`
 	// is eventually empty, we should have reached the end of the chain.
 	eventsToFetch := authEventIDs
-	authEventsMap := make(map[string]*gomatrixserverlib.Event)
+	authEventsMap := make(map[string]gomatrixserverlib.PDU)
 
 	for len(eventsToFetch) > 0 {
 		// Try to retrieve the events from the database.
@@ -673,14 +671,14 @@ func GetAuthChain(
 		for _, event := range events {
 			// Store the event in the event map - this prevents us from requesting it
 			// from the database again.
-			authEventsMap[event.EventID()] = event.Event
+			authEventsMap[event.EventID()] = event.PDU
 
 			// Extract all of the auth events from the newly obtained event. If we
 			// don't already have a record of the event, record it in the list of
 			// events we want to request for the next pass.
-			for _, authEvent := range event.AuthEvents() {
-				if _, ok := authEventsMap[authEvent.EventID]; !ok {
-					eventsToFetch = append(eventsToFetch, authEvent.EventID)
+			for _, authEventID := range event.AuthEventIDs() {
+				if _, ok := authEventsMap[authEventID]; !ok {
+					eventsToFetch = append(eventsToFetch, authEventID)
 				}
 			}
 		}
@@ -688,7 +686,7 @@ func GetAuthChain(
 
 	// We've now retrieved all of the events we can. Flatten them down into an
 	// array and return them.
-	var authEvents []*gomatrixserverlib.Event
+	var authEvents []gomatrixserverlib.PDU
 	for _, event := range authEventsMap {
 		authEvents = append(authEvents, event)
 	}
@@ -854,7 +852,7 @@ func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainReq
 	}
 	hchain := make([]*types.HeaderedEvent, len(chain))
 	for i := range chain {
-		hchain[i] = &types.HeaderedEvent{Event: chain[i]}
+		hchain[i] = &types.HeaderedEvent{PDU: chain[i]}
 	}
 	res.AuthChain = hchain
 	return nil
-- 
cgit v1.2.3