aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/query/query.go')
-rw-r--r--roomserver/internal/query/query.go130
1 files changed, 78 insertions, 52 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index ac34e0ff..c5b74422 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -21,11 +21,12 @@ import (
"errors"
"fmt"
- "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
+ "github.com/matrix-org/dendrite/roomserver/storage/tables"
+
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/acls"
@@ -102,7 +103,7 @@ func (r *Queryer) QueryStateAfterEvents(
return err
}
- stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
+ stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
if err != nil {
return err
}
@@ -114,7 +115,7 @@ func (r *Queryer) QueryStateAfterEvents(
}
authEventIDs = util.UniqueStrings(authEventIDs)
- authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
+ authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil {
return fmt.Errorf("getAuthChain: %w", err)
}
@@ -132,24 +133,46 @@ func (r *Queryer) QueryStateAfterEvents(
return nil
}
-// QueryEventsByID implements api.RoomserverInternalAPI
+// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
+// which room to use by querying the first events roomID.
func (r *Queryer) QueryEventsByID(
ctx context.Context,
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
- events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs)
+ if len(request.EventIDs) == 0 {
+ return nil
+ }
+ var err error
+ // We didn't receive a room ID, we need to fetch it first before we can continue.
+ // This happens for e.g. ` /_matrix/federation/v1/event/{eventId}`
+ var roomInfo *types.RoomInfo
+ if request.RoomID == "" {
+ var eventNIDs map[string]types.EventMetadata
+ eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]})
+ if err != nil {
+ return err
+ }
+ if len(eventNIDs) == 0 {
+ return nil
+ }
+ roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID)
+ } else {
+ roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID)
+ }
+ if err != nil {
+ return err
+ }
+ if roomInfo == nil {
+ return nil
+ }
+ events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs)
if err != nil {
return err
}
for _, event := range events {
- roomVersion, verr := r.roomVersion(event.RoomID())
- if verr != nil {
- return verr
- }
-
- response.Events = append(response.Events, event.Headered(roomVersion))
+ response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion))
}
return nil
@@ -186,7 +209,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom
response.HasBeenInRoom = true
- evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID})
+ evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID})
if err != nil {
return err
}
@@ -268,10 +291,10 @@ func (r *Queryer) QueryMembershipAtEvent(
// once. If we have more than one membership event, we need to get the state for each state entry.
if canShortCircuit {
if len(memberships) == 0 {
- memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
+ memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
}
} else {
- memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
+ memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
}
if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err)
@@ -318,7 +341,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
}
- events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
+ events, err = r.DB.Events(ctx, info, eventNIDs)
if err != nil {
return fmt.Errorf("r.DB.Events: %w", err)
}
@@ -357,14 +380,14 @@ func (r *Queryer) QueryMembershipsForRoom(
return err
}
- events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
+ events, err = r.DB.Events(ctx, info, eventNIDs)
} else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
}
- events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly)
+ events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly)
}
if err != nil {
@@ -412,39 +435,39 @@ func (r *Queryer) QueryServerJoinedToRoom(
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
func (r *Queryer) QueryServerAllowedToSeeEvent(
ctx context.Context,
- request *api.QueryServerAllowedToSeeEventRequest,
- response *api.QueryServerAllowedToSeeEventResponse,
-) (err error) {
- events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID})
+ serverName gomatrixserverlib.ServerName,
+ eventID string,
+) (allowed bool, err error) {
+ events, err := r.DB.EventNIDs(ctx, []string{eventID})
if err != nil {
return
}
if len(events) == 0 {
- response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
- return
- }
- roomID := events[0].RoomID()
-
- inRoomReq := &api.QueryServerJoinedToRoomRequest{
- RoomID: roomID,
- ServerName: request.ServerName,
+ return allowed, nil
}
- inRoomRes := &api.QueryServerJoinedToRoomResponse{}
- if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil {
- return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err)
- }
-
- info, err := r.DB.RoomInfo(ctx, roomID)
+ info, err := r.DB.RoomInfoByNID(ctx, events[eventID].RoomNID)
if err != nil {
- return err
+ return allowed, err
}
if info == nil || info.IsStub() {
- return nil
+ return allowed, nil
}
- response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
- ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom,
+ var isInRoom bool
+ if r.IsLocalServerName(serverName) || serverName == "" {
+ isInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID)
+ if err != nil {
+ return allowed, fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err)
+ }
+ } else {
+ isInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, serverName)
+ if err != nil {
+ return allowed, fmt.Errorf("r.DB.GetServerInRoom: %w", err)
+ }
+ }
+
+ return helpers.CheckServerAllowedToSeeEvent(
+ ctx, r.DB, info, eventID, serverName, isInRoom,
)
- return
}
// QueryMissingEvents implements api.RoomserverInternalAPI
@@ -466,19 +489,22 @@ func (r *Queryer) QueryMissingEvents(
eventsToFilter[id] = true
}
}
- events, err := r.DB.EventsFromIDs(ctx, 0, front)
+ if len(front) == 0 {
+ return nil // no events to query, give up.
+ }
+ events, err := r.DB.EventNIDs(ctx, []string{front[0]})
if err != nil {
return err
}
if len(events) == 0 {
return nil // we are missing the events being asked to search from, give up.
}
- info, err := r.DB.RoomInfo(ctx, events[0].RoomID())
+ info, err := r.DB.RoomInfoByNID(ctx, events[front[0]].RoomNID)
if err != nil {
return err
}
if info == nil || info.IsStub() {
- return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
+ return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
}
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
@@ -486,7 +512,7 @@ func (r *Queryer) QueryMissingEvents(
return err
}
- loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
+ loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
if err != nil {
return err
}
@@ -529,7 +555,7 @@ func (r *Queryer) QueryStateAndAuthChain(
// TODO: this probably means it should be a different query operation...
if request.OnlyFetchAuthChain {
var authEvents []*gomatrixserverlib.Event
- authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs)
+ authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs)
if err != nil {
return err
}
@@ -556,7 +582,7 @@ func (r *Queryer) QueryStateAndAuthChain(
}
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
- authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
+ authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil {
return err
}
@@ -611,18 +637,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
return nil, rejected, false, err
}
- events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries)
+ events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries)
return events, rejected, false, err
}
-type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error)
+type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error)
// GetAuthChain fetches the auth chain for the given auth events. An auth chain
// is the list of all events that are referenced in the auth_events section, and
// all their auth_events, recursively. The returned set of events contain the
// given events. Will *not* error if we don't have all auth events.
func GetAuthChain(
- ctx context.Context, fn eventsFromIDs, authEventIDs []string,
+ ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string,
) ([]*gomatrixserverlib.Event, error) {
// List of event IDs to fetch. On each pass, these events will be requested
// from the database and the `eventsToFetch` will be updated with any new
@@ -633,7 +659,7 @@ func GetAuthChain(
for len(eventsToFetch) > 0 {
// Try to retrieve the events from the database.
- events, err := fn(ctx, 0, eventsToFetch)
+ events, err := fn(ctx, roomInfo, eventsToFetch)
if err != nil {
return nil, err
}
@@ -852,7 +878,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS
}
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
- chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs)
+ chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs)
if err != nil {
return err
}
@@ -971,7 +997,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid
// membership event.
for _, joinNID := range joinNIDs {
- events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID})
+ events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID})
if err != nil || len(events) != 1 {
continue
}