aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/eventutil/types.go24
-rw-r--r--internal/pushgateway/client.go66
-rw-r--r--internal/pushgateway/pushgateway.go62
-rw-r--r--internal/pushrules/action.go102
-rw-r--r--internal/pushrules/action_test.go39
-rw-r--r--internal/pushrules/condition.go49
-rw-r--r--internal/pushrules/default.go23
-rw-r--r--internal/pushrules/default_content.go33
-rw-r--r--internal/pushrules/default_override.go165
-rw-r--r--internal/pushrules/default_underride.go119
-rw-r--r--internal/pushrules/evaluate.go165
-rw-r--r--internal/pushrules/evaluate_test.go189
-rw-r--r--internal/pushrules/pushrules.go71
-rw-r--r--internal/pushrules/util.go125
-rw-r--r--internal/pushrules/util_test.go169
-rw-r--r--internal/pushrules/validate.go85
-rw-r--r--internal/pushrules/validate_test.go163
-rw-r--r--internal/sqlutil/sql.go1
18 files changed, 1649 insertions, 1 deletions
diff --git a/internal/eventutil/types.go b/internal/eventutil/types.go
index 6d119ce6..17861d6c 100644
--- a/internal/eventutil/types.go
+++ b/internal/eventutil/types.go
@@ -26,8 +26,30 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID")
// AccountData represents account data sent from the client API server to the
// sync API server
type AccountData struct {
+ RoomID string `json:"room_id"`
+ Type string `json:"type"`
+ ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional
+}
+
+type ReadMarkerJSON struct {
+ FullyRead string `json:"m.fully_read"`
+ Read string `json:"m.read"`
+}
+
+// NotificationData contains statistics about notifications, sent from
+// the Push Server to the Sync API server.
+type NotificationData struct {
+ // RoomID identifies the scope of the statistics, together with
+ // MXID (which is encoded in the Kafka key).
RoomID string `json:"room_id"`
- Type string `json:"type"`
+
+ // HighlightCount is the number of unread notifications with the
+ // highlight tweak.
+ UnreadHighlightCount int `json:"unread_highlight_count"`
+
+ // UnreadNotificationCount is the total number of unread
+ // notifications.
+ UnreadNotificationCount int `json:"unread_notification_count"`
}
// ProfileResponse is a struct containing all known user profile data
diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go
new file mode 100644
index 00000000..49907cee
--- /dev/null
+++ b/internal/pushgateway/client.go
@@ -0,0 +1,66 @@
+package pushgateway
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/opentracing/opentracing-go"
+)
+
+type httpClient struct {
+ hc *http.Client
+}
+
+// NewHTTPClient creates a new Push Gateway client.
+func NewHTTPClient(disableTLSValidation bool) Client {
+ hc := &http.Client{
+ Timeout: 30 * time.Second,
+ Transport: &http.Transport{
+ DisableKeepAlives: true,
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: disableTLSValidation,
+ },
+ },
+ }
+ return &httpClient{hc: hc}
+}
+
+func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "Notify")
+ defer span.Finish()
+
+ body, err := json.Marshal(req)
+ if err != nil {
+ return err
+ }
+ hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return err
+ }
+ hreq.Header.Set("Content-Type", "application/json")
+
+ hresp, err := h.hc.Do(hreq)
+ if err != nil {
+ return err
+ }
+
+ //nolint:errcheck
+ defer hresp.Body.Close()
+
+ if hresp.StatusCode == http.StatusOK {
+ return json.NewDecoder(hresp.Body).Decode(resp)
+ }
+
+ var errorBody struct {
+ Message string `json:"message"`
+ }
+ if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil {
+ return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message)
+ }
+ return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url)
+}
diff --git a/internal/pushgateway/pushgateway.go b/internal/pushgateway/pushgateway.go
new file mode 100644
index 00000000..88c326eb
--- /dev/null
+++ b/internal/pushgateway/pushgateway.go
@@ -0,0 +1,62 @@
+package pushgateway
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// A Client is how interactions with a Push Gateway is done.
+type Client interface {
+ // Notify sends a notification to the gateway at the given URL.
+ Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error
+}
+
+type NotifyRequest struct {
+ Notification Notification `json:"notification"` // Required
+}
+
+type NotifyResponse struct {
+ // Rejected is the list of device push keys that were rejected
+ // during the push. The caller should remove the push keys so they
+ // are not used again.
+ Rejected []string `json:"rejected"` // Required
+}
+
+type Notification struct {
+ Content json.RawMessage `json:"content,omitempty"`
+ Counts *Counts `json:"counts,omitempty"`
+ Devices []*Device `json:"devices"` // Required
+ EventID string `json:"event_id,omitempty"`
+ ID string `json:"id,omitempty"` // Deprecated name for EventID.
+ Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest.
+ Prio Prio `json:"prio,omitempty"`
+ RoomAlias string `json:"room_alias,omitempty"`
+ RoomID string `json:"room_id,omitempty"`
+ RoomName string `json:"room_name,omitempty"`
+ Sender string `json:"sender,omitempty"`
+ SenderDisplayName string `json:"sender_display_name,omitempty"`
+ Type string `json:"type,omitempty"`
+ UserIsTarget bool `json:"user_is_target,omitempty"`
+}
+
+type Counts struct {
+ MissedCalls int `json:"missed_calls,omitempty"`
+ Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it.
+}
+
+type Device struct {
+ AppID string `json:"app_id"` // Required
+ Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
+ PushKey string `json:"pushkey"` // Required
+ PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
+ Tweaks map[string]interface{} `json:"tweaks,omitempty"`
+}
+
+type Prio string
+
+const (
+ HighPrio Prio = "high"
+ LowPrio Prio = "low"
+)
diff --git a/internal/pushrules/action.go b/internal/pushrules/action.go
new file mode 100644
index 00000000..c7b8cec8
--- /dev/null
+++ b/internal/pushrules/action.go
@@ -0,0 +1,102 @@
+package pushrules
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+)
+
+// An Action is (part of) an outcome of a rule. There are
+// (unofficially) terminal actions, and modifier actions.
+type Action struct {
+ // Kind is the type of action. Has custom encoding in JSON.
+ Kind ActionKind `json:"-"`
+
+ // Tweak is the property to tweak. Has custom encoding in JSON.
+ Tweak TweakKey `json:"-"`
+
+ // Value is some value interpreted according to Kind and Tweak.
+ Value interface{} `json:"value"`
+}
+
+func (a *Action) MarshalJSON() ([]byte, error) {
+ if a.Tweak == UnknownTweak && a.Value == nil {
+ return json.Marshal(a.Kind)
+ }
+
+ if a.Kind != SetTweakAction {
+ return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind)
+ }
+
+ m := map[string]interface{}{
+ string(a.Kind): a.Tweak,
+ }
+ if a.Value != nil {
+ m["value"] = a.Value
+ }
+
+ return json.Marshal(m)
+}
+
+func (a *Action) UnmarshalJSON(bs []byte) error {
+ if bytes.HasPrefix(bs, []byte("\"")) {
+ return json.Unmarshal(bs, &a.Kind)
+ }
+
+ var raw struct {
+ SetTweak TweakKey `json:"set_tweak"`
+ Value interface{} `json:"value"`
+ }
+ if err := json.Unmarshal(bs, &raw); err != nil {
+ return err
+ }
+ if raw.SetTweak == UnknownTweak {
+ return fmt.Errorf("got unknown action JSON: %s", string(bs))
+ }
+ a.Kind = SetTweakAction
+ a.Tweak = raw.SetTweak
+ if raw.Value != nil {
+ a.Value = raw.Value
+ }
+
+ return nil
+}
+
+// ActionKind is the primary discriminator for actions.
+type ActionKind string
+
+const (
+ UnknownAction ActionKind = ""
+
+ // NotifyAction indicates the clients should show a notification.
+ NotifyAction ActionKind = "notify"
+
+ // DontNotifyAction indicates the clients should not show a notification.
+ DontNotifyAction ActionKind = "dont_notify"
+
+ // CoalesceAction tells the clients to show a notification, and
+ // tells both servers and clients that multiple events can be
+ // coalesced into a single notification. The behaviour is
+ // implementation-specific.
+ CoalesceAction ActionKind = "coalesce"
+
+ // SetTweakAction uses the Tweak and Value fields to add a
+ // tweak. Multiple SetTweakAction can be provided in a rule,
+ // combined with NotifyAction or CoalesceAction.
+ SetTweakAction ActionKind = "set_tweak"
+)
+
+// A TweakKey describes a property to be modified/tweaked for events
+// that match the rule.
+type TweakKey string
+
+const (
+ UnknownTweak TweakKey = ""
+
+ // SoundTweak describes which sound to play. Using "default" means
+ // "enable sound".
+ SoundTweak TweakKey = "sound"
+
+ // HighlightTweak asks the clients to highlight the conversation.
+ HighlightTweak TweakKey = "highlight"
+)
diff --git a/internal/pushrules/action_test.go b/internal/pushrules/action_test.go
new file mode 100644
index 00000000..72db9c99
--- /dev/null
+++ b/internal/pushrules/action_test.go
@@ -0,0 +1,39 @@
+package pushrules
+
+import (
+ "encoding/json"
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestActionJSON(t *testing.T) {
+ tsts := []struct {
+ Want Action
+ }{
+ {Action{Kind: NotifyAction}},
+ {Action{Kind: DontNotifyAction}},
+ {Action{Kind: CoalesceAction}},
+ {Action{Kind: SetTweakAction}},
+
+ {Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}},
+ {Action{Kind: SetTweakAction, Tweak: HighlightTweak}},
+ {Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}},
+ }
+ for _, tst := range tsts {
+ t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) {
+ bs, err := json.Marshal(&tst.Want)
+ if err != nil {
+ t.Fatalf("Marshal failed: %v", err)
+ }
+ var got Action
+ if err := json.Unmarshal(bs, &got); err != nil {
+ t.Fatalf("Unmarshal failed: %v", err)
+ }
+ if diff := cmp.Diff(tst.Want, got); diff != "" {
+ t.Errorf("+got -want:\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go
new file mode 100644
index 00000000..2d9773c0
--- /dev/null
+++ b/internal/pushrules/condition.go
@@ -0,0 +1,49 @@
+package pushrules
+
+// A Condition dictates extra conditions for a matching rules. See
+// ConditionKind.
+type Condition struct {
+ // Kind is the primary discriminator for the condition
+ // type. Required.
+ Kind ConditionKind `json:"kind"`
+
+ // Key indicates the dot-separated path of Event fields to
+ // match. Required for EventMatchCondition and
+ // SenderNotificationPermissionCondition.
+ Key string `json:"key,omitempty"`
+
+ // Pattern indicates the value pattern that must match. Required
+ // for EventMatchCondition.
+ Pattern string `json:"pattern,omitempty"`
+
+ // Is indicates the condition that must be fulfilled. Required for
+ // RoomMemberCountCondition.
+ Is string `json:"is,omitempty"`
+}
+
+// ConditionKind represents a kind of condition.
+//
+// SPEC: Unrecognised conditions MUST NOT match any events,
+// effectively making the push rule disabled.
+type ConditionKind string
+
+const (
+ UnknownCondition ConditionKind = ""
+
+ // EventMatchCondition indicates the condition looks for a key
+ // path and matches a pattern. How paths that don't reference a
+ // simple value match against rules is implementation-specific.
+ EventMatchCondition ConditionKind = "event_match"
+
+ // ContainsDisplayNameCondition indicates the current user's
+ // display name must be found in the content body.
+ ContainsDisplayNameCondition ConditionKind = "contains_display_name"
+
+ // RoomMemberCountCondition matches a simple arithmetic comparison
+ // against the total number of members in a room.
+ RoomMemberCountCondition ConditionKind = "room_member_count"
+
+ // SenderNotificationPermissionCondition compares power level for
+ // the sender in the event's room.
+ SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission"
+)
diff --git a/internal/pushrules/default.go b/internal/pushrules/default.go
new file mode 100644
index 00000000..99698551
--- /dev/null
+++ b/internal/pushrules/default.go
@@ -0,0 +1,23 @@
+package pushrules
+
+import (
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// DefaultAccountRuleSets is the complete set of default push rules
+// for an account.
+func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets {
+ return &AccountRuleSets{
+ Global: *DefaultGlobalRuleSet(localpart, serverName),
+ }
+}
+
+// DefaultGlobalRuleSet returns the default ruleset for a given (fully
+// qualified) MXID.
+func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet {
+ return &RuleSet{
+ Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)),
+ Content: defaultContentRules(localpart),
+ Underride: defaultUnderrideRules,
+ }
+}
diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go
new file mode 100644
index 00000000..158afd18
--- /dev/null
+++ b/internal/pushrules/default_content.go
@@ -0,0 +1,33 @@
+package pushrules
+
+func defaultContentRules(localpart string) []*Rule {
+ return []*Rule{
+ mRuleContainsUserNameDefinition(localpart),
+ }
+}
+
+const (
+ MRuleContainsUserName = ".m.rule.contains_user_name"
+)
+
+func mRuleContainsUserNameDefinition(localpart string) *Rule {
+ return &Rule{
+ RuleID: MRuleContainsUserName,
+ Default: true,
+ Enabled: true,
+ Pattern: localpart,
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: SoundTweak,
+ Value: "default",
+ },
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: true,
+ },
+ },
+ }
+}
diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go
new file mode 100644
index 00000000..6f66fd66
--- /dev/null
+++ b/internal/pushrules/default_override.go
@@ -0,0 +1,165 @@
+package pushrules
+
+func defaultOverrideRules(userID string) []*Rule {
+ return []*Rule{
+ &mRuleMasterDefinition,
+ &mRuleSuppressNoticesDefinition,
+ mRuleInviteForMeDefinition(userID),
+ &mRuleMemberEventDefinition,
+ &mRuleContainsDisplayNameDefinition,
+ &mRuleTombstoneDefinition,
+ &mRuleRoomNotifDefinition,
+ }
+}
+
+const (
+ MRuleMaster = ".m.rule.master"
+ MRuleSuppressNotices = ".m.rule.suppress_notices"
+ MRuleInviteForMe = ".m.rule.invite_for_me"
+ MRuleMemberEvent = ".m.rule.member_event"
+ MRuleContainsDisplayName = ".m.rule.contains_display_name"
+ MRuleTombstone = ".m.rule.tombstone"
+ MRuleRoomNotif = ".m.rule.roomnotif"
+)
+
+var (
+ mRuleMasterDefinition = Rule{
+ RuleID: MRuleMaster,
+ Default: true,
+ Enabled: false,
+ Conditions: []*Condition{},
+ Actions: []*Action{{Kind: DontNotifyAction}},
+ }
+ mRuleSuppressNoticesDefinition = Rule{
+ RuleID: MRuleSuppressNotices,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "content.msgtype",
+ Pattern: "m.notice",
+ },
+ },
+ Actions: []*Action{{Kind: DontNotifyAction}},
+ }
+ mRuleMemberEventDefinition = Rule{
+ RuleID: MRuleMemberEvent,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.member",
+ },
+ },
+ Actions: []*Action{{Kind: DontNotifyAction}},
+ }
+ mRuleContainsDisplayNameDefinition = Rule{
+ RuleID: MRuleContainsDisplayName,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}},
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: SoundTweak,
+ Value: "default",
+ },
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: true,
+ },
+ },
+ }
+ mRuleTombstoneDefinition = Rule{
+ RuleID: MRuleTombstone,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.tombstone",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "state_key",
+ Pattern: "",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+ mRuleRoomNotifDefinition = Rule{
+ RuleID: MRuleRoomNotif,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "content.body",
+ Pattern: "@room",
+ },
+ {
+ Kind: SenderNotificationPermissionCondition,
+ Key: "room",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+)
+
+func mRuleInviteForMeDefinition(userID string) *Rule {
+ return &Rule{
+ RuleID: MRuleInviteForMe,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.member",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "content.membership",
+ Pattern: "invite",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "state_key",
+ Pattern: userID,
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: SoundTweak,
+ Value: "default",
+ },
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+}
diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go
new file mode 100644
index 00000000..de72bd52
--- /dev/null
+++ b/internal/pushrules/default_underride.go
@@ -0,0 +1,119 @@
+package pushrules
+
+const (
+ MRuleCall = ".m.rule.call"
+ MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one"
+ MRuleRoomOneToOne = ".m.rule.room_one_to_one"
+ MRuleMessage = ".m.rule.message"
+ MRuleEncrypted = ".m.rule.encrypted"
+)
+
+var defaultUnderrideRules = []*Rule{
+ &mRuleCallDefinition,
+ &mRuleEncryptedRoomOneToOneDefinition,
+ &mRuleRoomOneToOneDefinition,
+ &mRuleMessageDefinition,
+ &mRuleEncryptedDefinition,
+}
+
+var (
+ mRuleCallDefinition = Rule{
+ RuleID: MRuleCall,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.call.invite",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: SoundTweak,
+ Value: "ring",
+ },
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+ mRuleEncryptedRoomOneToOneDefinition = Rule{
+ RuleID: MRuleEncryptedRoomOneToOne,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: RoomMemberCountCondition,
+ Is: "2",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.encrypted",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+ mRuleRoomOneToOneDefinition = Rule{
+ RuleID: MRuleRoomOneToOne,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: RoomMemberCountCondition,
+ Is: "2",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.message",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+ mRuleMessageDefinition = Rule{
+ RuleID: MRuleMessage,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.message",
+ },
+ },
+ Actions: []*Action{{Kind: NotifyAction}},
+ }
+ mRuleEncryptedDefinition = Rule{
+ RuleID: MRuleEncrypted,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.encrypted",
+ },
+ },
+ Actions: []*Action{{Kind: NotifyAction}},
+ }
+)
diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go
new file mode 100644
index 00000000..df22cb04
--- /dev/null
+++ b/internal/pushrules/evaluate.go
@@ -0,0 +1,165 @@
+package pushrules
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// A RuleSetEvaluator encapsulates context to evaluate an event
+// against a rule set.
+type RuleSetEvaluator struct {
+ ec EvaluationContext
+ ruleSet []kindAndRules
+}
+
+// An EvaluationContext gives a RuleSetEvaluator access to the
+// environment, for rules that require that.
+type EvaluationContext interface {
+ // UserDisplayName returns the current user's display name.
+ UserDisplayName() string
+
+ // RoomMemberCount returns the number of members in the room of
+ // the current event.
+ RoomMemberCount() (int, error)
+
+ // HasPowerLevel returns whether the user has at least the given
+ // power in the room of the current event.
+ HasPowerLevel(userID, levelKey string) (bool, error)
+}
+
+// A kindAndRules is just here to simplify iteration of the (ordered)
+// kinds of rules.
+type kindAndRules struct {
+ Kind Kind
+ Rules []*Rule
+}
+
+// NewRuleSetEvaluator creates a new evaluator for the given rule set.
+func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator {
+ return &RuleSetEvaluator{
+ ec: ec,
+ ruleSet: []kindAndRules{
+ {OverrideKind, ruleSet.Override},
+ {ContentKind, ruleSet.Content},
+ {RoomKind, ruleSet.Room},
+ {SenderKind, ruleSet.Sender},
+ {UnderrideKind, ruleSet.Underride},
+ },
+ }
+}
+
+// MatchEvent returns the first matching rule. Returns nil if there
+// was no match rule.
+func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) {
+ // TODO: server-default rules have lower priority than user rules,
+ // but they are stored together with the user rules. It's a bit
+ // unclear what the specification (11.14.1.4 Predefined rules)
+ // means the ordering should be.
+ //
+ // The most reasonable interpretation is that default overrides
+ // still have lower priority than user content rules, so we
+ // iterate twice.
+ for _, rsat := range rse.ruleSet {
+ for _, defRules := range []bool{false, true} {
+ for _, rule := range rsat.Rules {
+ if rule.Default != defRules {
+ continue
+ }
+ ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec)
+ if err != nil {
+ return nil, err
+ }
+ if ok {
+ return rule, nil
+ }
+ }
+ }
+ }
+
+ // No matching rule.
+ return nil, nil
+}
+
+func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
+ if !rule.Enabled {
+ return false, nil
+ }
+
+ switch kind {
+ case OverrideKind, UnderrideKind:
+ for _, cond := range rule.Conditions {
+ ok, err := conditionMatches(cond, event, ec)
+ if err != nil {
+ return false, err
+ }
+ if !ok {
+ return false, nil
+ }
+ }
+ return true, nil
+
+ case ContentKind:
+ // TODO: "These configure behaviour for (unencrypted) messages
+ // that match certain patterns." - Does that mean "content.body"?
+ return patternMatches("content.body", rule.Pattern, event)
+
+ case RoomKind:
+ return rule.RuleID == event.RoomID(), nil
+
+ case SenderKind:
+ return rule.RuleID == event.Sender(), nil
+
+ default:
+ return false, nil
+ }
+}
+
+func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
+ switch cond.Kind {
+ case EventMatchCondition:
+ return patternMatches(cond.Key, cond.Pattern, event)
+
+ case ContainsDisplayNameCondition:
+ return patternMatches("content.body", ec.UserDisplayName(), event)
+
+ case RoomMemberCountCondition:
+ cmp, err := parseRoomMemberCountCondition(cond.Is)
+ if err != nil {
+ return false, fmt.Errorf("parsing room_member_count condition: %w", err)
+ }
+ n, err := ec.RoomMemberCount()
+ if err != nil {
+ return false, fmt.Errorf("RoomMemberCount failed: %w", err)
+ }
+ return cmp(n), nil
+
+ case SenderNotificationPermissionCondition:
+ return ec.HasPowerLevel(event.Sender(), cond.Key)
+
+ default:
+ return false, nil
+ }
+}
+
+func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) {
+ re, err := globToRegexp(pattern)
+ if err != nil {
+ return false, err
+ }
+
+ var eventMap map[string]interface{}
+ if err = json.Unmarshal(event.JSON(), &eventMap); err != nil {
+ return false, fmt.Errorf("parsing event: %w", err)
+ }
+ v, err := lookupMapPath(strings.Split(key, "."), eventMap)
+ if err != nil {
+ // An unknown path is a benign error that shouldn't stop rule
+ // processing. It's just a non-match.
+ return false, nil
+ }
+
+ return re.MatchString(fmt.Sprint(v)), nil
+}
diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go
new file mode 100644
index 00000000..50e70336
--- /dev/null
+++ b/internal/pushrules/evaluate_test.go
@@ -0,0 +1,189 @@
+package pushrules
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
+ ev := mustEventFromJSON(t, `{}`)
+ defaultEnabled := &Rule{
+ RuleID: ".default.enabled",
+ Default: true,
+ Enabled: true,
+ }
+ userEnabled := &Rule{
+ RuleID: ".user.enabled",
+ Default: false,
+ Enabled: true,
+ }
+ userEnabled2 := &Rule{
+ RuleID: ".user.enabled.2",
+ Default: false,
+ Enabled: true,
+ }
+ tsts := []struct {
+ Name string
+ RuleSet RuleSet
+ Want *Rule
+ }{
+ {"empty", RuleSet{}, nil},
+ {"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled},
+ {"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled},
+ {"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled},
+ {"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled},
+ {"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled},
+ {"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled},
+ {"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ rse := NewRuleSetEvaluator(nil, &tst.RuleSet)
+ got, err := rse.MatchEvent(ev)
+ if err != nil {
+ t.Fatalf("MatchEvent failed: %v", err)
+ }
+ if diff := cmp.Diff(tst.Want, got); diff != "" {
+ t.Errorf("MatchEvent rule: +got -want:\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestRuleMatches(t *testing.T) {
+ emptyRule := Rule{Enabled: true}
+ tsts := []struct {
+ Name string
+ Kind Kind
+ Rule Rule
+ EventJSON string
+ Want bool
+ }{
+ {"emptyOverride", OverrideKind, emptyRule, `{}`, true},
+ {"emptyContent", ContentKind, emptyRule, `{}`, false},
+ {"emptyRoom", RoomKind, emptyRule, `{}`, true},
+ {"emptySender", SenderKind, emptyRule, `{}`, true},
+ {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
+
+ {"disabled", OverrideKind, Rule{}, `{}`, false},
+
+ {"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true},
+ {"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
+
+ {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
+ {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
+
+ {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true},
+ {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false},
+
+ {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
+ {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
+
+ {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true},
+ {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil)
+ if err != nil {
+ t.Fatalf("ruleMatches failed: %v", err)
+ }
+ if got != tst.Want {
+ t.Errorf("ruleMatches: got %v, want %v", got, tst.Want)
+ }
+ })
+ }
+}
+
+func TestConditionMatches(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Cond Condition
+ EventJSON string
+ Want bool
+ }{
+ {"empty", Condition{}, `{}`, false},
+ {"empty", Condition{Kind: "unknownstring"}, `{}`, false},
+
+ {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true},
+
+ {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false},
+ {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true},
+
+ {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false},
+ {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true},
+ {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false},
+ {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true},
+ {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false},
+ {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true},
+ {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false},
+ {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true},
+ {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false},
+ {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true},
+
+ {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true},
+ {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{})
+ if err != nil {
+ t.Fatalf("conditionMatches failed: %v", err)
+ }
+ if got != tst.Want {
+ t.Errorf("conditionMatches: got %v, want %v", got, tst.Want)
+ }
+ })
+ }
+}
+
+type fakeEvaluationContext struct{}
+
+func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" }
+func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil }
+func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) {
+ return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil
+}
+
+func TestPatternMatches(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Key string
+ Pattern string
+ EventJSON string
+ Want bool
+ }{
+ {"empty", "", "", `{}`, false},
+
+ // Note that an empty pattern contains no wildcard characters,
+ // which implicitly means "*".
+ {"patternEmpty", "content", "", `{"content":{}}`, true},
+
+ {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true},
+ {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true},
+ {"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true},
+ {"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true},
+ {"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON))
+ if err != nil {
+ t.Fatalf("patternMatches failed: %v", err)
+ }
+ if got != tst.Want {
+ t.Errorf("patternMatches: got %v, want %v", got, tst.Want)
+ }
+ })
+ }
+}
+
+func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event {
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ev
+}
diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go
new file mode 100644
index 00000000..bbed1f95
--- /dev/null
+++ b/internal/pushrules/pushrules.go
@@ -0,0 +1,71 @@
+package pushrules
+
+// An AccountRuleSets carries the rule sets associated with an
+// account.
+type AccountRuleSets struct {
+ Global RuleSet `json:"global"` // Required
+}
+
+// A RuleSet contains all the various push rules for an
+// account. Listed in decreasing order of priority.
+type RuleSet struct {
+ Override []*Rule `json:"override,omitempty"`
+ Content []*Rule `json:"content,omitempty"`
+ Room []*Rule `json:"room,omitempty"`
+ Sender []*Rule `json:"sender,omitempty"`
+ Underride []*Rule `json:"underride,omitempty"`
+}
+
+// A Rule contains matchers, conditions and final actions. While
+// evaluating, at most one rule is considered matching.
+//
+// Kind and scope are part of the push rules request/responses, but
+// not of the core data model.
+type Rule struct {
+ // RuleID is either a free identifier, or the sender's MXID for
+ // SenderKind. Required.
+ RuleID string `json:"rule_id"`
+
+ // Default indicates whether this is a server-defined default, or
+ // a user-provided rule. Required.
+ //
+ // The server-default rules have the lowest priority.
+ Default bool `json:"default"`
+
+ // Enabled allows the user to disable rules while keeping them
+ // around. Required.
+ Enabled bool `json:"enabled"`
+
+ // Actions describe the desired outcome, should the rule
+ // match. Required.
+ Actions []*Action `json:"actions"`
+
+ // Conditions provide the rule's conditions for OverrideKind and
+ // UnderrideKind. Not allowed for other kinds.
+ Conditions []*Condition `json:"conditions"`
+
+ // Pattern is the body pattern to match for ContentKind. Required
+ // for that kind. The interpretation is the same as that of
+ // Condition.Pattern.
+ Pattern string `json:"pattern"`
+}
+
+// Scope only has one valid value. See also AccountRuleSets.
+type Scope string
+
+const (
+ UnknownScope Scope = ""
+ GlobalScope Scope = "global"
+)
+
+// Kind is the type of push rule. See also RuleSet.
+type Kind string
+
+const (
+ UnknownKind Kind = ""
+ OverrideKind Kind = "override"
+ ContentKind Kind = "content"
+ RoomKind Kind = "room"
+ SenderKind Kind = "sender"
+ UnderrideKind Kind = "underride"
+)
diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go
new file mode 100644
index 00000000..027d35ef
--- /dev/null
+++ b/internal/pushrules/util.go
@@ -0,0 +1,125 @@
+package pushrules
+
+import (
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+)
+
+// ActionsToTweaks converts a list of actions into a primary action
+// kind and a tweaks map. Returns a nil map if it would have been
+// empty.
+func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) {
+ var kind ActionKind
+ tweaks := map[string]interface{}{}
+
+ for _, a := range as {
+ if a.Kind == SetTweakAction {
+ tweaks[string(a.Tweak)] = a.Value
+ continue
+ }
+ if kind != UnknownAction {
+ return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind)
+ }
+ kind = a.Kind
+ }
+
+ if len(tweaks) == 0 {
+ tweaks = nil
+ }
+
+ return kind, tweaks, nil
+}
+
+// BoolTweakOr returns the named tweak as a boolean, and returns `def`
+// on failure.
+func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool {
+ v, ok := tweaks[string(key)]
+ if !ok {
+ return def
+ }
+ b, ok := v.(bool)
+ if !ok {
+ return def
+ }
+ return b
+}
+
+// globToRegexp converts a Matrix glob-style pattern to a Regular expression.
+func globToRegexp(pattern string) (*regexp.Regexp, error) {
+ // TODO: It's unclear which glob characters are supported. The only
+ // place this is discussed is for the unrelated "m.policy.rule.*"
+ // events. Assuming, the same: /[*?]/
+ if !strings.ContainsAny(pattern, "*?") {
+ pattern = "*" + pattern + "*"
+ }
+
+ // The defined syntax doesn't allow escaping the glob wildcard
+ // characters, which makes this a straight-forward
+ // replace-after-quote.
+ pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta)
+ pattern = strings.Replace(pattern, "*", ".*", -1)
+ pattern = strings.Replace(pattern, "?", ".", -1)
+ return regexp.Compile("^(" + pattern + ")$")
+}
+
+// globNonMetaRegexp are the characters that are not considered glob
+// meta-characters (i.e. may need escaping).
+var globNonMetaRegexp = regexp.MustCompile("[^*?]+")
+
+// lookupMapPath traverses a hierarchical map structure, like the one
+// produced by json.Unmarshal, to return the leaf value. Traversing
+// arrays/slices is not supported, only objects/maps.
+func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) {
+ if len(path) == 0 {
+ return nil, fmt.Errorf("empty path")
+ }
+
+ var v interface{} = m
+ for i, key := range path {
+ m, ok := v.(map[string]interface{})
+ if !ok {
+ return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v)
+ }
+
+ v, ok = m[key]
+ if !ok {
+ return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], "."))
+ }
+ }
+
+ return v, nil
+}
+
+// parseRoomMemberCountCondition parses a string like "2", "==2", "<2"
+// into a function that checks if the argument to it fulfils the
+// condition.
+func parseRoomMemberCountCondition(s string) (func(int) bool, error) {
+ var b int
+ var cmp = func(a int) bool { return a == b }
+ switch {
+ case strings.HasPrefix(s, "<="):
+ cmp = func(a int) bool { return a <= b }
+ s = s[2:]
+ case strings.HasPrefix(s, ">="):
+ cmp = func(a int) bool { return a >= b }
+ s = s[2:]
+ case strings.HasPrefix(s, "<"):
+ cmp = func(a int) bool { return a < b }
+ s = s[1:]
+ case strings.HasPrefix(s, ">"):
+ cmp = func(a int) bool { return a > b }
+ s = s[1:]
+ case strings.HasPrefix(s, "=="):
+ // Same cmp as the default.
+ s = s[2:]
+ }
+
+ v, err := strconv.ParseInt(s, 10, 64)
+ if err != nil {
+ return nil, err
+ }
+ b = int(v)
+ return cmp, nil
+}
diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go
new file mode 100644
index 00000000..a951c55a
--- /dev/null
+++ b/internal/pushrules/util_test.go
@@ -0,0 +1,169 @@
+package pushrules
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestActionsToTweaks(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Input []*Action
+ WantKind ActionKind
+ WantTweaks map[string]interface{}
+ }{
+ {"empty", nil, UnknownAction, nil},
+ {"zero", []*Action{{}}, UnknownAction, nil},
+ {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil},
+ {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}},
+ {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}},
+ {
+ "all",
+ []*Action{
+ {Kind: CoalesceAction},
+ {Kind: SetTweakAction, Tweak: HighlightTweak},
+ {Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"},
+ },
+ CoalesceAction,
+ map[string]interface{}{"highlight": nil, "sound": "default"},
+ },
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ gotKind, gotTweaks, err := ActionsToTweaks(tst.Input)
+ if err != nil {
+ t.Fatalf("ActionsToTweaks failed: %v", err)
+ }
+ if gotKind != tst.WantKind {
+ t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind)
+ }
+ if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" {
+ t.Errorf("tweaks: +got -want:\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestBoolTweakOr(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Input map[string]interface{}
+ Def bool
+ Want bool
+ }{
+ {"nil", nil, false, false},
+ {"nilValue", map[string]interface{}{"highlight": nil}, true, true},
+ {"false", map[string]interface{}{"highlight": false}, true, false},
+ {"true", map[string]interface{}{"highlight": true}, false, true},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def)
+ if got != tst.Want {
+ t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want)
+ }
+ })
+ }
+}
+
+func TestGlobToRegexp(t *testing.T) {
+ tsts := []struct {
+ Input string
+ Want string
+ }{
+ {"", "^(.*.*)$"},
+ {"a", "^(.*a.*)$"},
+ {"a.b", "^(.*a\\.b.*)$"},
+ {"a?b", "^(a.b)$"},
+ {"a*b*", "^(a.*b.*)$"},
+ {"a*b?", "^(a.*b.)$"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Want, func(t *testing.T) {
+ got, err := globToRegexp(tst.Input)
+ if err != nil {
+ t.Fatalf("globToRegexp failed: %v", err)
+ }
+ if got.String() != tst.Want {
+ t.Errorf("got %v, want %v", got.String(), tst.Want)
+ }
+ })
+ }
+}
+
+func TestLookupMapPath(t *testing.T) {
+ tsts := []struct {
+ Path []string
+ Root map[string]interface{}
+ Want interface{}
+ }{
+ {[]string{"a"}, map[string]interface{}{"a": "b"}, "b"},
+ {[]string{"a"}, map[string]interface{}{"a": 42}, 42},
+ {[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"},
+ }
+ for _, tst := range tsts {
+ t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) {
+ got, err := lookupMapPath(tst.Path, tst.Root)
+ if err != nil {
+ t.Fatalf("lookupMapPath failed: %v", err)
+ }
+ if diff := cmp.Diff(tst.Want, got); diff != "" {
+ t.Errorf("+got -want:\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLookupMapPathInvalid(t *testing.T) {
+ tsts := []struct {
+ Path []string
+ Root map[string]interface{}
+ }{
+ {nil, nil},
+ {[]string{"a"}, nil},
+ {[]string{"a", "b"}, map[string]interface{}{"a": "c"}},
+ }
+ for _, tst := range tsts {
+ t.Run(fmt.Sprint(tst.Path), func(t *testing.T) {
+ got, err := lookupMapPath(tst.Path, tst.Root)
+ if err == nil {
+ t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got)
+ }
+ })
+ }
+}
+
+func TestParseRoomMemberCountCondition(t *testing.T) {
+ tsts := []struct {
+ Input string
+ WantTrue []int
+ WantFalse []int
+ }{
+ {"1", []int{1}, []int{0, 2}},
+ {"==1", []int{1}, []int{0, 2}},
+ {"<1", []int{0}, []int{1, 2}},
+ {"<=1", []int{0, 1}, []int{2}},
+ {">1", []int{2}, []int{0, 1}},
+ {">=42", []int{42, 43}, []int{41}},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Input, func(t *testing.T) {
+ got, err := parseRoomMemberCountCondition(tst.Input)
+ if err != nil {
+ t.Fatalf("parseRoomMemberCountCondition failed: %v", err)
+ }
+ for _, v := range tst.WantTrue {
+ if !got(v) {
+ t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v)
+ }
+ }
+ for _, v := range tst.WantFalse {
+ if got(v) {
+ t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v)
+ }
+ }
+ })
+ }
+}
diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go
new file mode 100644
index 00000000..5d260f0b
--- /dev/null
+++ b/internal/pushrules/validate.go
@@ -0,0 +1,85 @@
+package pushrules
+
+import (
+ "fmt"
+ "regexp"
+)
+
+// ValidateRule checks the rule for errors. These follow from Sytests
+// and the specification.
+func ValidateRule(kind Kind, rule *Rule) []error {
+ var errs []error
+
+ if !validRuleIDRE.MatchString(rule.RuleID) {
+ errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
+ }
+
+ if len(rule.Actions) == 0 {
+ errs = append(errs, fmt.Errorf("missing actions"))
+ }
+ for _, action := range rule.Actions {
+ errs = append(errs, validateAction(action)...)
+ }
+
+ for _, cond := range rule.Conditions {
+ errs = append(errs, validateCondition(cond)...)
+ }
+
+ switch kind {
+ case OverrideKind, UnderrideKind:
+ // The empty list is allowed, but for JSON-encoding reasons,
+ // it must not be nil.
+ if rule.Conditions == nil {
+ errs = append(errs, fmt.Errorf("missing rule conditions"))
+ }
+
+ case ContentKind:
+ if rule.Pattern == "" {
+ errs = append(errs, fmt.Errorf("missing content rule pattern"))
+ }
+
+ case RoomKind, SenderKind:
+ // Do nothing.
+
+ default:
+ errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind))
+ }
+
+ return errs
+}
+
+// validRuleIDRE is a regexp for valid IDs.
+//
+// TODO: the specification doesn't seem to say what the rule ID syntax
+// is. A Sytest fails if it contains a backslash.
+var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`)
+
+// validateAction returns issues with an Action.
+func validateAction(action *Action) []error {
+ var errs []error
+
+ switch action.Kind {
+ case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction:
+ // Do nothing.
+
+ default:
+ errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind))
+ }
+
+ return errs
+}
+
+// validateCondition returns issues with a Condition.
+func validateCondition(cond *Condition) []error {
+ var errs []error
+
+ switch cond.Kind {
+ case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition:
+ // Do nothing.
+
+ default:
+ errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind))
+ }
+
+ return errs
+}
diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go
new file mode 100644
index 00000000..b276eb55
--- /dev/null
+++ b/internal/pushrules/validate_test.go
@@ -0,0 +1,163 @@
+package pushrules
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestValidateRuleNegatives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Kind Kind
+ Rule Rule
+ WantErrString string
+ }{
+ {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"},
+ {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"},
+ {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"},
+ {"noActions", OverrideKind, Rule{}, "missing actions"},
+ {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"},
+ {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"},
+ {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"},
+ {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"},
+ {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := ValidateRule(tst.Kind, &tst.Rule)
+ var foundErr error
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantErrString) {
+ foundErr = err
+ }
+ }
+ if foundErr == nil {
+ t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
+ }
+ })
+ }
+}
+
+func TestValidateRulePositives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Kind Kind
+ Rule Rule
+ WantNoErrString string
+ }{
+ {"invalidKind", OverrideKind, Rule{}, "invalid rule kind"},
+ {"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"},
+ {"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"},
+ {"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"},
+ {"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"},
+ {"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"},
+ {"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"},
+ {"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := ValidateRule(tst.Kind, &tst.Rule)
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantNoErrString) {
+ t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
+ }
+ }
+ })
+ }
+}
+
+func TestValidateActionNegatives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Action Action
+ WantErrString string
+ }{
+ {"emptyKind", Action{}, "invalid rule action kind"},
+ {"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := validateAction(&tst.Action)
+ var foundErr error
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantErrString) {
+ foundErr = err
+ }
+ }
+ if foundErr == nil {
+ t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
+ }
+ })
+ }
+}
+
+func TestValidateActionPositives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Action Action
+ WantNoErrString string
+ }{
+ {"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := validateAction(&tst.Action)
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantNoErrString) {
+ t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
+ }
+ }
+ })
+ }
+}
+
+func TestValidateConditionNegatives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Condition Condition
+ WantErrString string
+ }{
+ {"emptyKind", Condition{}, "invalid rule condition kind"},
+ {"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := validateCondition(&tst.Condition)
+ var foundErr error
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantErrString) {
+ foundErr = err
+ }
+ }
+ if foundErr == nil {
+ t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
+ }
+ })
+ }
+}
+
+func TestValidateConditionPositives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Condition Condition
+ WantNoErrString string
+ }{
+ {"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := validateCondition(&tst.Condition)
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantNoErrString) {
+ t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
+ }
+ }
+ })
+ }
+}
diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go
index 8d0d2dfa..19483b26 100644
--- a/internal/sqlutil/sql.go
+++ b/internal/sqlutil/sql.go
@@ -163,6 +163,7 @@ type StatementList []struct {
func (s StatementList) Prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.Statement, err = db.Prepare(statement.SQL); err != nil {
+ err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL)
return
}
}