aboutsummaryrefslogtreecommitdiff
path: root/syncapi/routing/relations.go
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi/routing/relations.go')
-rw-r--r--syncapi/routing/relations.go33
1 files changed, 22 insertions, 11 deletions
diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go
index 17933b2f..e3d1069a 100644
--- a/syncapi/routing/relations.go
+++ b/syncapi/routing/relations.go
@@ -43,9 +43,25 @@ func Relations(
req *http.Request, device *userapi.Device,
syncDB storage.Database,
rsAPI api.SyncRoomserverAPI,
- roomID, eventID, relType, eventType string,
+ rawRoomID, eventID, relType, eventType string,
) util.JSONResponse {
- var err error
+ roomID, err := spec.NewRoomID(rawRoomID)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("invalid room ID"),
+ }
+ }
+
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.Unknown("internal server error"),
+ }
+ }
+
var from, to types.StreamPosition
var limit int
dir := req.URL.Query().Get("dir")
@@ -93,7 +109,7 @@ func Relations(
}
var events []types.StreamEvent
events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor(
- req.Context(), roomID, eventID, relType, eventType, from, to, dir == "b", limit,
+ req.Context(), roomID.String(), eventID, relType, eventType, from, to, dir == "b", limit,
)
if err != nil {
return util.ErrorResponse(err)
@@ -105,12 +121,7 @@ func Relations(
}
// Apply history visibility to the result events.
- filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, device.UserID, "relations")
- if err != nil {
- return util.ErrorResponse(err)
- }
-
- validRoomID, err := spec.NewRoomID(roomID)
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, *userID, "relations")
if err != nil {
return util.ErrorResponse(err)
}
@@ -120,14 +131,14 @@ func Relations(
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents {
sender := spec.UserID{}
- userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
+ userID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
- skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
+ skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString