aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/query')
-rw-r--r--roomserver/internal/query/query.go14
-rw-r--r--roomserver/internal/query/query_test.go4
2 files changed, 10 insertions, 8 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 7346c7a7..2a361641 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -107,7 +107,7 @@ func (r *Queryer) QueryStateAfterEvents(
}
authEventIDs = util.UniqueStrings(authEventIDs)
- authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
+ authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
if err != nil {
return fmt.Errorf("getAuthChain: %w", err)
}
@@ -447,10 +447,12 @@ func (r *Queryer) QueryStateAndAuthChain(
response.RoomExists = true
response.RoomVersion = info.RoomVersion
- stateEvents, err := r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs)
+ var stateEvents []*gomatrixserverlib.Event
+ stateEvents, err = r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs)
if err != nil {
return err
}
+
response.PrevEventsExist = true
// add the auth event IDs for the current state events too
@@ -461,7 +463,7 @@ func (r *Queryer) QueryStateAndAuthChain(
}
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
- authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
+ authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
if err != nil {
return err
}
@@ -510,11 +512,11 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomIn
type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
-// getAuthChain fetches the auth chain for the given auth events. An auth chain
+// GetAuthChain fetches the auth chain for the given auth events. An auth chain
// is the list of all events that are referenced in the auth_events section, and
// all their auth_events, recursively. The returned set of events contain the
// given events. Will *not* error if we don't have all auth events.
-func getAuthChain(
+func GetAuthChain(
ctx context.Context, fn eventsFromIDs, authEventIDs []string,
) ([]*gomatrixserverlib.Event, error) {
// List of event IDs to fetch. On each pass, these events will be requested
@@ -718,7 +720,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS
}
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
- chain, err := getAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs)
+ chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs)
if err != nil {
return err
}
diff --git a/roomserver/internal/query/query_test.go b/roomserver/internal/query/query_test.go
index 4e761d8e..ba5bb9f5 100644
--- a/roomserver/internal/query/query_test.go
+++ b/roomserver/internal/query/query_test.go
@@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err)
}
- result, err := getAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"})
+ result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}
@@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err)
}
- result, err := getAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"})
+ result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}