aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-04-14 13:35:27 +0200
committerGitHub <noreply@github.com>2023-04-14 12:35:27 +0100
commitc45d8cd68875f7f5081f252cfdc2dd32f99c58f8 (patch)
treeee3d2fd0df11fc7bebc0d57ffa6040c5fb2bd203
parentca63b414da87f7bdb25effffd187d51191a42b3e (diff)
Add pushrules tests (#3044)
partly takes care of https://github.com/matrix-org/dendrite/issues/2870 by making sure that rule IDs don't start with a dot. Co-authored-by: kegsay <kegan@matrix.org>
-rw-r--r--clientapi/clientapi_test.go395
-rw-r--r--clientapi/routing/pushrules.go55
-rw-r--r--internal/pushrules/validate.go4
-rw-r--r--userapi/api/api.go17
-rw-r--r--userapi/internal/user_api.go27
5 files changed, 430 insertions, 68 deletions
diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go
index 76295ba5..0be26273 100644
--- a/clientapi/clientapi_test.go
+++ b/clientapi/clientapi_test.go
@@ -8,10 +8,12 @@ import (
"io"
"net/http"
"net/http/httptest"
+ "net/url"
"strings"
"testing"
"time"
+ "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -1235,3 +1237,396 @@ func Test3PID(t *testing.T) {
}
})
}
+
+func TestPushRules(t *testing.T) {
+ alice := test.NewUser(t)
+
+ // create the default push rules, used when validating responses
+ localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
+ pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
+ defaultRules, err := json.Marshal(pushRuleSets)
+ assert.NoError(t, err)
+
+ ruleID1 := "myrule"
+ ruleID2 := "myrule2"
+ ruleID3 := "myrule3"
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ cfg.ClientAPI.RateLimiting.Enabled = false
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+ natsInstance := jetstream.NATSInstance{}
+ defer close()
+
+ routers := httputil.NewRouters()
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
+
+ // We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
+ AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
+
+ accessTokens := map[*test.User]userDevice{
+ alice: {},
+ }
+ createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
+
+ testCases := []struct {
+ name string
+ request *http.Request
+ wantStatusCode int
+ validateFunc func(t *testing.T, respBody *bytes.Buffer) // used when updating rules, otherwise wantStatusCode should be enough
+ queryAttr map[string]string
+ }{
+ {
+ name: "can not get rules without trailing slash",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can get default rules",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/", strings.NewReader("")),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ assert.Equal(t, defaultRules, respBody.Bytes())
+ },
+ },
+ {
+ name: "can get rules by scope",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/", strings.NewReader("")),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ assert.Equal(t, gjson.GetBytes(defaultRules, "global").Raw, respBody.String())
+ },
+ },
+ {
+ name: "can not get invalid rules by scope",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not get rules for invalid scope and kind",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/invalid/", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not get rules for invalid kind",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can get rules by scope and kind",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/", strings.NewReader("")),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ assert.Equal(t, gjson.GetBytes(defaultRules, "global.override").Raw, respBody.String())
+ },
+ },
+ {
+ name: "can get rules by scope and content kind",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ assert.Equal(t, gjson.GetBytes(defaultRules, "global.content").Raw, respBody.String())
+ },
+ },
+ {
+ name: "can not get rules by scope and room kind",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/room/", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not get rules by scope and sender kind",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/sender/", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can get rules by scope and underride kind",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/underride/", strings.NewReader("")),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ assert.Equal(t, gjson.GetBytes(defaultRules, "global.underride").Raw, respBody.String())
+ },
+ },
+ {
+ name: "can not get rules by scope, kind and ID for invalid scope",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/doesnotexist/.m.rule.master", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not get rules by scope, kind and ID for invalid kind",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/doesnotexist/.m.rule.master", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can get rules by scope, kind and ID",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master", strings.NewReader("")),
+ wantStatusCode: http.StatusOK,
+ },
+ {
+ name: "can not get rules by scope, kind and ID for invalid ID",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.doesnotexist", strings.NewReader("")),
+ wantStatusCode: http.StatusNotFound,
+ },
+ {
+ name: "can not get status for invalid attribute",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/invalid", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not get status for invalid kind",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not get enabled status for invalid scope",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not get enabled status for invalid rule",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/doesnotexist/enabled", strings.NewReader("")),
+ wantStatusCode: http.StatusNotFound,
+ },
+ {
+ name: "can get enabled rules by scope, kind and ID",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ assert.False(t, gjson.GetBytes(respBody.Bytes(), "enabled").Bool(), "expected master rule to be disabled")
+ },
+ },
+ {
+ name: "can get actions scope, kind and ID",
+ request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader("")),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ actions := gjson.GetBytes(respBody.Bytes(), "actions").Array()
+ // only a basic check
+ assert.Equal(t, 1, len(actions))
+ },
+ },
+ {
+ name: "can not set enabled status with invalid JSON",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not set attribute for invalid attribute",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/doesnotexist", strings.NewReader("{}")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not set attribute for invalid scope",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("{}")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not set attribute for invalid kind",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("{}")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not set attribute for invalid rule",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/invalid/enabled", strings.NewReader("{}")),
+ wantStatusCode: http.StatusNotFound,
+ },
+ {
+ name: "can set enabled status with valid JSON",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(`{"enabled":true}`)),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(""))
+ req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
+ routers.Client.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
+ assert.True(t, gjson.GetBytes(rec.Body.Bytes(), "enabled").Bool(), "expected master rule to be enabled: %s", rec.Body.String())
+ },
+ },
+ {
+ name: "can set actions with valid JSON",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(`{"actions":["dont_notify","notify"]}`)),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(""))
+ req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
+ routers.Client.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
+ assert.Equal(t, 2, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 2 actions %s", rec.Body.String())
+ },
+ },
+ {
+ name: "can not create new push rule with invalid JSON",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not create new push rule with invalid rule content",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("{}")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not create new push rule with invalid scope",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can create new push rule with valid rule content",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule/actions", strings.NewReader(""))
+ req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
+ routers.Client.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
+ assert.Equal(t, 1, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 1 action %s", rec.Body.String())
+ },
+ },
+ {
+ name: "can not create new push starting with a dot",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/.myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can create new push rule after existing",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
+ queryAttr: map[string]string{
+ "after": ruleID1,
+ },
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
+ req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
+ routers.Client.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
+ rules := gjson.ParseBytes(rec.Body.Bytes())
+ for i, rule := range rules.Array() {
+ if rule.Get("rule_id").Str == ruleID1 && i != 0 {
+ t.Fatalf("expected '%s' to be the first, but wasn't", ruleID1)
+ }
+ if rule.Get("rule_id").Str == ruleID2 && i != 1 {
+ t.Fatalf("expected '%s' to be the second, but wasn't", ruleID2)
+ }
+ }
+ },
+ },
+ {
+ name: "can create new push rule before existing",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule3", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
+ queryAttr: map[string]string{
+ "before": ruleID1,
+ },
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
+ req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
+ routers.Client.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
+ rules := gjson.ParseBytes(rec.Body.Bytes())
+ for i, rule := range rules.Array() {
+ if rule.Get("rule_id").Str == ruleID3 && i != 0 {
+ t.Fatalf("expected '%s' to be the first, but wasn't", ruleID3)
+ }
+ if rule.Get("rule_id").Str == ruleID1 && i != 1 {
+ t.Fatalf("expected '%s' to be the second, but wasn't", ruleID1)
+ }
+ if rule.Get("rule_id").Str == ruleID2 && i != 2 {
+ t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
+ }
+ }
+ },
+ },
+ {
+ name: "can modify existing push rule",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule2/actions", strings.NewReader(""))
+ req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
+ routers.Client.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
+ actions := gjson.GetBytes(rec.Body.Bytes(), "actions").Array()
+ // there should only be one action
+ assert.Equal(t, "dont_notify", actions[0].Str)
+ },
+ },
+ {
+ name: "can move existing push rule to the front",
+ request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
+ queryAttr: map[string]string{
+ "before": ruleID3,
+ },
+ wantStatusCode: http.StatusOK,
+ validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
+ req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
+ routers.Client.ServeHTTP(rec, req)
+ assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
+ rules := gjson.ParseBytes(rec.Body.Bytes())
+ for i, rule := range rules.Array() {
+ if rule.Get("rule_id").Str == ruleID2 && i != 0 {
+ t.Fatalf("expected '%s' to be the first, but wasn't", ruleID2)
+ }
+ if rule.Get("rule_id").Str == ruleID3 && i != 1 {
+ t.Fatalf("expected '%s' to be the second, but wasn't", ruleID3)
+ }
+ if rule.Get("rule_id").Str == ruleID1 && i != 2 {
+ t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
+ }
+ }
+ },
+ },
+ {
+ name: "can not delete push rule with invalid scope",
+ request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/invalid/content/myrule2", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not delete push rule with invalid kind",
+ request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/invalid/myrule2", strings.NewReader("")),
+ wantStatusCode: http.StatusBadRequest,
+ },
+ {
+ name: "can not delete push rule with non-existent rule",
+ request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/doesnotexist", strings.NewReader("")),
+ wantStatusCode: http.StatusNotFound,
+ },
+ {
+ name: "can delete existing push rule",
+ request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader("")),
+ wantStatusCode: http.StatusOK,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ rec := httptest.NewRecorder()
+
+ if tc.queryAttr != nil {
+ params := url.Values{}
+ for k, v := range tc.queryAttr {
+ params.Set(k, v)
+ }
+
+ tc.request = httptest.NewRequest(tc.request.Method, tc.request.URL.String()+"?"+params.Encode(), tc.request.Body)
+ }
+
+ tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
+
+ routers.Client.ServeHTTP(rec, tc.request)
+ assert.Equal(t, tc.wantStatusCode, rec.Code, rec.Body.String())
+ if tc.validateFunc != nil {
+ tc.validateFunc(t, rec.Body)
+ }
+ t.Logf("%s", rec.Body.String())
+ })
+ }
+ })
+}
diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go
index 856f52c7..f1a539ad 100644
--- a/clientapi/routing/pushrules.go
+++ b/clientapi/routing/pushrules.go
@@ -31,7 +31,7 @@ func errorResponse(ctx context.Context, err error, msg string, args ...interface
}
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
- ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed")
}
@@ -42,7 +42,7 @@ func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userap
}
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
- ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed")
}
@@ -57,7 +57,7 @@ func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Devi
}
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
- ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
@@ -66,7 +66,8 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
- if rulesPtr == nil {
+ // Even if rulesPtr is not nil, there may not be any rules for this kind
+ if rulesPtr == nil || (rulesPtr != nil && len(*rulesPtr) == 0) {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
return util.JSONResponse{
@@ -76,7 +77,7 @@ func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi
}
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
- ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
@@ -101,7 +102,10 @@ func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
var newRule pushrules.Rule
if err := json.NewDecoder(body).Decode(&newRule); err != nil {
- return errorResponse(ctx, err, "JSON Decode failed")
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.BadJSON(err.Error()),
+ }
}
newRule.RuleID = ruleID
@@ -110,7 +114,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
}
- ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
@@ -120,6 +124,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
+ // while this should be impossible (ValidateRule would already return an error), better keep it around
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
@@ -144,7 +149,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
}
// Add new rule.
- i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
+ i, err = findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
if err != nil {
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
}
@@ -153,7 +158,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
}
- if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
+ if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
return errorResponse(ctx, err, "putPushRules failed")
}
@@ -161,7 +166,7 @@ func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID,
}
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
- ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
@@ -180,7 +185,7 @@ func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, dev
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
- if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
+ if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
return errorResponse(ctx, err, "putPushRules failed")
}
@@ -192,7 +197,7 @@ func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
if err != nil {
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
}
- ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
@@ -238,7 +243,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
return errorResponse(ctx, err, "pushRuleAttrSetter failed")
}
- ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ ruleSets, err := userAPI.QueryPushRules(ctx, device.UserID)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
@@ -258,7 +263,7 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
attrSet((*rulesPtr)[i], &newPartialRule)
- if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
+ if err = userAPI.PerformPushRulesPut(ctx, device.UserID, ruleSets); err != nil {
return errorResponse(ctx, err, "putPushRules failed")
}
}
@@ -266,28 +271,6 @@ func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr stri
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}
-func queryPushRules(ctx context.Context, userID string, userAPI userapi.ClientUserAPI) (*pushrules.AccountRuleSets, error) {
- var res userapi.QueryPushRulesResponse
- if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
- util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
- return nil, err
- }
- return res.RuleSets, nil
-}
-
-func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.ClientUserAPI) error {
- req := userapi.PerformPushRulesPutRequest{
- UserID: userID,
- RuleSets: ruleSets,
- }
- var res struct{}
- if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
- util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
- return err
- }
- return nil
-}
-
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
switch scope {
case pushrules.GlobalScope:
diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go
index f50c51bd..b54ec3fb 100644
--- a/internal/pushrules/validate.go
+++ b/internal/pushrules/validate.go
@@ -10,6 +10,10 @@ import (
func ValidateRule(kind Kind, rule *Rule) []error {
var errs []error
+ if len(rule.RuleID) > 0 && rule.RuleID[:1] == "." {
+ errs = append(errs, fmt.Errorf("invalid rule ID: rule can not start with a dot"))
+ }
+
if !validRuleIDRE.MatchString(rule.RuleID) {
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
}
diff --git a/userapi/api/api.go b/userapi/api/api.go
index ba1c374f..7c47efd2 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -90,7 +90,7 @@ type ClientUserAPI interface {
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
- QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error
+ QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
@@ -99,7 +99,7 @@ type ClientUserAPI interface {
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
- PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error
+ PerformPushRulesPut(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
@@ -555,19 +555,6 @@ const (
HTTPKind PusherKind = "http"
)
-type PerformPushRulesPutRequest struct {
- UserID string `json:"user_id"`
- RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
-}
-
-type QueryPushRulesRequest struct {
- UserID string `json:"user_id"`
-}
-
-type QueryPushRulesResponse struct {
- RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
-}
-
type QueryNotificationsRequest struct {
Localpart string `json:"localpart"` // Required.
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go
index 6dad91dd..139ca758 100644
--- a/userapi/internal/user_api.go
+++ b/userapi/internal/user_api.go
@@ -26,6 +26,7 @@ import (
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
+ "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@@ -872,36 +873,28 @@ func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPusher
func (a *UserInternalAPI) PerformPushRulesPut(
ctx context.Context,
- req *api.PerformPushRulesPutRequest,
- _ *struct{},
+ userID string,
+ ruleSets *pushrules.AccountRuleSets,
) error {
- bs, err := json.Marshal(&req.RuleSets)
+ bs, err := json.Marshal(ruleSets)
if err != nil {
return err
}
userReq := api.InputAccountDataRequest{
- UserID: req.UserID,
+ UserID: userID,
DataType: pushRulesAccountDataType,
AccountData: json.RawMessage(bs),
}
var userRes api.InputAccountDataResponse // empty
- if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil {
- return err
- }
- return nil
+ return a.InputAccountData(ctx, &userReq, &userRes)
}
-func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
- localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
- if err != nil {
- return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
- }
- pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
+func (a *UserInternalAPI) QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) {
+ localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
- return fmt.Errorf("failed to query push rules: %w", err)
+ return nil, fmt.Errorf("failed to split user ID %q for push rules", userID)
}
- res.RuleSets = pushRules
- return nil
+ return a.DB.QueryPushRules(ctx, localpart, domain)
}
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) {