diff options
Diffstat (limited to 'internal')
-rw-r--r-- | internal/eventutil/types.go | 24 | ||||
-rw-r--r-- | internal/pushgateway/client.go | 66 | ||||
-rw-r--r-- | internal/pushgateway/pushgateway.go | 62 | ||||
-rw-r--r-- | internal/pushrules/action.go | 102 | ||||
-rw-r--r-- | internal/pushrules/action_test.go | 39 | ||||
-rw-r--r-- | internal/pushrules/condition.go | 49 | ||||
-rw-r--r-- | internal/pushrules/default.go | 23 | ||||
-rw-r--r-- | internal/pushrules/default_content.go | 33 | ||||
-rw-r--r-- | internal/pushrules/default_override.go | 165 | ||||
-rw-r--r-- | internal/pushrules/default_underride.go | 119 | ||||
-rw-r--r-- | internal/pushrules/evaluate.go | 165 | ||||
-rw-r--r-- | internal/pushrules/evaluate_test.go | 189 | ||||
-rw-r--r-- | internal/pushrules/pushrules.go | 71 | ||||
-rw-r--r-- | internal/pushrules/util.go | 125 | ||||
-rw-r--r-- | internal/pushrules/util_test.go | 169 | ||||
-rw-r--r-- | internal/pushrules/validate.go | 85 | ||||
-rw-r--r-- | internal/pushrules/validate_test.go | 163 | ||||
-rw-r--r-- | internal/sqlutil/sql.go | 1 |
18 files changed, 1649 insertions, 1 deletions
diff --git a/internal/eventutil/types.go b/internal/eventutil/types.go index 6d119ce6..17861d6c 100644 --- a/internal/eventutil/types.go +++ b/internal/eventutil/types.go @@ -26,8 +26,30 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID") // AccountData represents account data sent from the client API server to the // sync API server type AccountData struct { + RoomID string `json:"room_id"` + Type string `json:"type"` + ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional +} + +type ReadMarkerJSON struct { + FullyRead string `json:"m.fully_read"` + Read string `json:"m.read"` +} + +// NotificationData contains statistics about notifications, sent from +// the Push Server to the Sync API server. +type NotificationData struct { + // RoomID identifies the scope of the statistics, together with + // MXID (which is encoded in the Kafka key). RoomID string `json:"room_id"` - Type string `json:"type"` + + // HighlightCount is the number of unread notifications with the + // highlight tweak. + UnreadHighlightCount int `json:"unread_highlight_count"` + + // UnreadNotificationCount is the total number of unread + // notifications. + UnreadNotificationCount int `json:"unread_notification_count"` } // ProfileResponse is a struct containing all known user profile data diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go new file mode 100644 index 00000000..49907cee --- /dev/null +++ b/internal/pushgateway/client.go @@ -0,0 +1,66 @@ +package pushgateway + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/opentracing/opentracing-go" +) + +type httpClient struct { + hc *http.Client +} + +// NewHTTPClient creates a new Push Gateway client. +func NewHTTPClient(disableTLSValidation bool) Client { + hc := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: disableTLSValidation, + }, + }, + } + return &httpClient{hc: hc} +} + +func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "Notify") + defer span.Finish() + + body, err := json.Marshal(req) + if err != nil { + return err + } + hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + hreq.Header.Set("Content-Type", "application/json") + + hresp, err := h.hc.Do(hreq) + if err != nil { + return err + } + + //nolint:errcheck + defer hresp.Body.Close() + + if hresp.StatusCode == http.StatusOK { + return json.NewDecoder(hresp.Body).Decode(resp) + } + + var errorBody struct { + Message string `json:"message"` + } + if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil { + return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message) + } + return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url) +} diff --git a/internal/pushgateway/pushgateway.go b/internal/pushgateway/pushgateway.go new file mode 100644 index 00000000..88c326eb --- /dev/null +++ b/internal/pushgateway/pushgateway.go @@ -0,0 +1,62 @@ +package pushgateway + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A Client is how interactions with a Push Gateway is done. +type Client interface { + // Notify sends a notification to the gateway at the given URL. + Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error +} + +type NotifyRequest struct { + Notification Notification `json:"notification"` // Required +} + +type NotifyResponse struct { + // Rejected is the list of device push keys that were rejected + // during the push. The caller should remove the push keys so they + // are not used again. + Rejected []string `json:"rejected"` // Required +} + +type Notification struct { + Content json.RawMessage `json:"content,omitempty"` + Counts *Counts `json:"counts,omitempty"` + Devices []*Device `json:"devices"` // Required + EventID string `json:"event_id,omitempty"` + ID string `json:"id,omitempty"` // Deprecated name for EventID. + Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest. + Prio Prio `json:"prio,omitempty"` + RoomAlias string `json:"room_alias,omitempty"` + RoomID string `json:"room_id,omitempty"` + RoomName string `json:"room_name,omitempty"` + Sender string `json:"sender,omitempty"` + SenderDisplayName string `json:"sender_display_name,omitempty"` + Type string `json:"type,omitempty"` + UserIsTarget bool `json:"user_is_target,omitempty"` +} + +type Counts struct { + MissedCalls int `json:"missed_calls,omitempty"` + Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it. +} + +type Device struct { + AppID string `json:"app_id"` // Required + Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys. + PushKey string `json:"pushkey"` // Required + PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` + Tweaks map[string]interface{} `json:"tweaks,omitempty"` +} + +type Prio string + +const ( + HighPrio Prio = "high" + LowPrio Prio = "low" +) diff --git a/internal/pushrules/action.go b/internal/pushrules/action.go new file mode 100644 index 00000000..c7b8cec8 --- /dev/null +++ b/internal/pushrules/action.go @@ -0,0 +1,102 @@ +package pushrules + +import ( + "bytes" + "encoding/json" + "fmt" +) + +// An Action is (part of) an outcome of a rule. There are +// (unofficially) terminal actions, and modifier actions. +type Action struct { + // Kind is the type of action. Has custom encoding in JSON. + Kind ActionKind `json:"-"` + + // Tweak is the property to tweak. Has custom encoding in JSON. + Tweak TweakKey `json:"-"` + + // Value is some value interpreted according to Kind and Tweak. + Value interface{} `json:"value"` +} + +func (a *Action) MarshalJSON() ([]byte, error) { + if a.Tweak == UnknownTweak && a.Value == nil { + return json.Marshal(a.Kind) + } + + if a.Kind != SetTweakAction { + return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind) + } + + m := map[string]interface{}{ + string(a.Kind): a.Tweak, + } + if a.Value != nil { + m["value"] = a.Value + } + + return json.Marshal(m) +} + +func (a *Action) UnmarshalJSON(bs []byte) error { + if bytes.HasPrefix(bs, []byte("\"")) { + return json.Unmarshal(bs, &a.Kind) + } + + var raw struct { + SetTweak TweakKey `json:"set_tweak"` + Value interface{} `json:"value"` + } + if err := json.Unmarshal(bs, &raw); err != nil { + return err + } + if raw.SetTweak == UnknownTweak { + return fmt.Errorf("got unknown action JSON: %s", string(bs)) + } + a.Kind = SetTweakAction + a.Tweak = raw.SetTweak + if raw.Value != nil { + a.Value = raw.Value + } + + return nil +} + +// ActionKind is the primary discriminator for actions. +type ActionKind string + +const ( + UnknownAction ActionKind = "" + + // NotifyAction indicates the clients should show a notification. + NotifyAction ActionKind = "notify" + + // DontNotifyAction indicates the clients should not show a notification. + DontNotifyAction ActionKind = "dont_notify" + + // CoalesceAction tells the clients to show a notification, and + // tells both servers and clients that multiple events can be + // coalesced into a single notification. The behaviour is + // implementation-specific. + CoalesceAction ActionKind = "coalesce" + + // SetTweakAction uses the Tweak and Value fields to add a + // tweak. Multiple SetTweakAction can be provided in a rule, + // combined with NotifyAction or CoalesceAction. + SetTweakAction ActionKind = "set_tweak" +) + +// A TweakKey describes a property to be modified/tweaked for events +// that match the rule. +type TweakKey string + +const ( + UnknownTweak TweakKey = "" + + // SoundTweak describes which sound to play. Using "default" means + // "enable sound". + SoundTweak TweakKey = "sound" + + // HighlightTweak asks the clients to highlight the conversation. + HighlightTweak TweakKey = "highlight" +) diff --git a/internal/pushrules/action_test.go b/internal/pushrules/action_test.go new file mode 100644 index 00000000..72db9c99 --- /dev/null +++ b/internal/pushrules/action_test.go @@ -0,0 +1,39 @@ +package pushrules + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestActionJSON(t *testing.T) { + tsts := []struct { + Want Action + }{ + {Action{Kind: NotifyAction}}, + {Action{Kind: DontNotifyAction}}, + {Action{Kind: CoalesceAction}}, + {Action{Kind: SetTweakAction}}, + + {Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, + {Action{Kind: SetTweakAction, Tweak: HighlightTweak}}, + {Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}}, + } + for _, tst := range tsts { + t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) { + bs, err := json.Marshal(&tst.Want) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + var got Action + if err := json.Unmarshal(bs, &got); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("+got -want:\n%s", diff) + } + }) + } +} diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go new file mode 100644 index 00000000..2d9773c0 --- /dev/null +++ b/internal/pushrules/condition.go @@ -0,0 +1,49 @@ +package pushrules + +// A Condition dictates extra conditions for a matching rules. See +// ConditionKind. +type Condition struct { + // Kind is the primary discriminator for the condition + // type. Required. + Kind ConditionKind `json:"kind"` + + // Key indicates the dot-separated path of Event fields to + // match. Required for EventMatchCondition and + // SenderNotificationPermissionCondition. + Key string `json:"key,omitempty"` + + // Pattern indicates the value pattern that must match. Required + // for EventMatchCondition. + Pattern string `json:"pattern,omitempty"` + + // Is indicates the condition that must be fulfilled. Required for + // RoomMemberCountCondition. + Is string `json:"is,omitempty"` +} + +// ConditionKind represents a kind of condition. +// +// SPEC: Unrecognised conditions MUST NOT match any events, +// effectively making the push rule disabled. +type ConditionKind string + +const ( + UnknownCondition ConditionKind = "" + + // EventMatchCondition indicates the condition looks for a key + // path and matches a pattern. How paths that don't reference a + // simple value match against rules is implementation-specific. + EventMatchCondition ConditionKind = "event_match" + + // ContainsDisplayNameCondition indicates the current user's + // display name must be found in the content body. + ContainsDisplayNameCondition ConditionKind = "contains_display_name" + + // RoomMemberCountCondition matches a simple arithmetic comparison + // against the total number of members in a room. + RoomMemberCountCondition ConditionKind = "room_member_count" + + // SenderNotificationPermissionCondition compares power level for + // the sender in the event's room. + SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission" +) diff --git a/internal/pushrules/default.go b/internal/pushrules/default.go new file mode 100644 index 00000000..99698551 --- /dev/null +++ b/internal/pushrules/default.go @@ -0,0 +1,23 @@ +package pushrules + +import ( + "github.com/matrix-org/gomatrixserverlib" +) + +// DefaultAccountRuleSets is the complete set of default push rules +// for an account. +func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets { + return &AccountRuleSets{ + Global: *DefaultGlobalRuleSet(localpart, serverName), + } +} + +// DefaultGlobalRuleSet returns the default ruleset for a given (fully +// qualified) MXID. +func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet { + return &RuleSet{ + Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)), + Content: defaultContentRules(localpart), + Underride: defaultUnderrideRules, + } +} diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go new file mode 100644 index 00000000..158afd18 --- /dev/null +++ b/internal/pushrules/default_content.go @@ -0,0 +1,33 @@ +package pushrules + +func defaultContentRules(localpart string) []*Rule { + return []*Rule{ + mRuleContainsUserNameDefinition(localpart), + } +} + +const ( + MRuleContainsUserName = ".m.rule.contains_user_name" +) + +func mRuleContainsUserNameDefinition(localpart string) *Rule { + return &Rule{ + RuleID: MRuleContainsUserName, + Default: true, + Enabled: true, + Pattern: localpart, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: true, + }, + }, + } +} diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go new file mode 100644 index 00000000..6f66fd66 --- /dev/null +++ b/internal/pushrules/default_override.go @@ -0,0 +1,165 @@ +package pushrules + +func defaultOverrideRules(userID string) []*Rule { + return []*Rule{ + &mRuleMasterDefinition, + &mRuleSuppressNoticesDefinition, + mRuleInviteForMeDefinition(userID), + &mRuleMemberEventDefinition, + &mRuleContainsDisplayNameDefinition, + &mRuleTombstoneDefinition, + &mRuleRoomNotifDefinition, + } +} + +const ( + MRuleMaster = ".m.rule.master" + MRuleSuppressNotices = ".m.rule.suppress_notices" + MRuleInviteForMe = ".m.rule.invite_for_me" + MRuleMemberEvent = ".m.rule.member_event" + MRuleContainsDisplayName = ".m.rule.contains_display_name" + MRuleTombstone = ".m.rule.tombstone" + MRuleRoomNotif = ".m.rule.roomnotif" +) + +var ( + mRuleMasterDefinition = Rule{ + RuleID: MRuleMaster, + Default: true, + Enabled: false, + Conditions: []*Condition{}, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleSuppressNoticesDefinition = Rule{ + RuleID: MRuleSuppressNotices, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "content.msgtype", + Pattern: "m.notice", + }, + }, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleMemberEventDefinition = Rule{ + RuleID: MRuleMemberEvent, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.member", + }, + }, + Actions: []*Action{{Kind: DontNotifyAction}}, + } + mRuleContainsDisplayNameDefinition = Rule{ + RuleID: MRuleContainsDisplayName, + Default: true, + Enabled: true, + Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}}, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: true, + }, + }, + } + mRuleTombstoneDefinition = Rule{ + RuleID: MRuleTombstone, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.tombstone", + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: "", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleRoomNotifDefinition = Rule{ + RuleID: MRuleRoomNotif, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "content.body", + Pattern: "@room", + }, + { + Kind: SenderNotificationPermissionCondition, + Key: "room", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } +) + +func mRuleInviteForMeDefinition(userID string) *Rule { + return &Rule{ + RuleID: MRuleInviteForMe, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.member", + }, + { + Kind: EventMatchCondition, + Key: "content.membership", + Pattern: "invite", + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: userID, + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "default", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } +} diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go new file mode 100644 index 00000000..de72bd52 --- /dev/null +++ b/internal/pushrules/default_underride.go @@ -0,0 +1,119 @@ +package pushrules + +const ( + MRuleCall = ".m.rule.call" + MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one" + MRuleRoomOneToOne = ".m.rule.room_one_to_one" + MRuleMessage = ".m.rule.message" + MRuleEncrypted = ".m.rule.encrypted" +) + +var defaultUnderrideRules = []*Rule{ + &mRuleCallDefinition, + &mRuleEncryptedRoomOneToOneDefinition, + &mRuleRoomOneToOneDefinition, + &mRuleMessageDefinition, + &mRuleEncryptedDefinition, +} + +var ( + mRuleCallDefinition = Rule{ + RuleID: MRuleCall, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.call.invite", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: SoundTweak, + Value: "ring", + }, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleEncryptedRoomOneToOneDefinition = Rule{ + RuleID: MRuleEncryptedRoomOneToOne, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: RoomMemberCountCondition, + Is: "2", + }, + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.encrypted", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleRoomOneToOneDefinition = Rule{ + RuleID: MRuleRoomOneToOne, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: RoomMemberCountCondition, + Is: "2", + }, + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.message", + }, + }, + Actions: []*Action{ + {Kind: NotifyAction}, + { + Kind: SetTweakAction, + Tweak: HighlightTweak, + Value: false, + }, + }, + } + mRuleMessageDefinition = Rule{ + RuleID: MRuleMessage, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.message", + }, + }, + Actions: []*Action{{Kind: NotifyAction}}, + } + mRuleEncryptedDefinition = Rule{ + RuleID: MRuleEncrypted, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: "m.room.encrypted", + }, + }, + Actions: []*Action{{Kind: NotifyAction}}, + } +) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go new file mode 100644 index 00000000..df22cb04 --- /dev/null +++ b/internal/pushrules/evaluate.go @@ -0,0 +1,165 @@ +package pushrules + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A RuleSetEvaluator encapsulates context to evaluate an event +// against a rule set. +type RuleSetEvaluator struct { + ec EvaluationContext + ruleSet []kindAndRules +} + +// An EvaluationContext gives a RuleSetEvaluator access to the +// environment, for rules that require that. +type EvaluationContext interface { + // UserDisplayName returns the current user's display name. + UserDisplayName() string + + // RoomMemberCount returns the number of members in the room of + // the current event. + RoomMemberCount() (int, error) + + // HasPowerLevel returns whether the user has at least the given + // power in the room of the current event. + HasPowerLevel(userID, levelKey string) (bool, error) +} + +// A kindAndRules is just here to simplify iteration of the (ordered) +// kinds of rules. +type kindAndRules struct { + Kind Kind + Rules []*Rule +} + +// NewRuleSetEvaluator creates a new evaluator for the given rule set. +func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator { + return &RuleSetEvaluator{ + ec: ec, + ruleSet: []kindAndRules{ + {OverrideKind, ruleSet.Override}, + {ContentKind, ruleSet.Content}, + {RoomKind, ruleSet.Room}, + {SenderKind, ruleSet.Sender}, + {UnderrideKind, ruleSet.Underride}, + }, + } +} + +// MatchEvent returns the first matching rule. Returns nil if there +// was no match rule. +func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) { + // TODO: server-default rules have lower priority than user rules, + // but they are stored together with the user rules. It's a bit + // unclear what the specification (11.14.1.4 Predefined rules) + // means the ordering should be. + // + // The most reasonable interpretation is that default overrides + // still have lower priority than user content rules, so we + // iterate twice. + for _, rsat := range rse.ruleSet { + for _, defRules := range []bool{false, true} { + for _, rule := range rsat.Rules { + if rule.Default != defRules { + continue + } + ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) + if err != nil { + return nil, err + } + if ok { + return rule, nil + } + } + } + } + + // No matching rule. + return nil, nil +} + +func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { + if !rule.Enabled { + return false, nil + } + + switch kind { + case OverrideKind, UnderrideKind: + for _, cond := range rule.Conditions { + ok, err := conditionMatches(cond, event, ec) + if err != nil { + return false, err + } + if !ok { + return false, nil + } + } + return true, nil + + case ContentKind: + // TODO: "These configure behaviour for (unencrypted) messages + // that match certain patterns." - Does that mean "content.body"? + return patternMatches("content.body", rule.Pattern, event) + + case RoomKind: + return rule.RuleID == event.RoomID(), nil + + case SenderKind: + return rule.RuleID == event.Sender(), nil + + default: + return false, nil + } +} + +func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { + switch cond.Kind { + case EventMatchCondition: + return patternMatches(cond.Key, cond.Pattern, event) + + case ContainsDisplayNameCondition: + return patternMatches("content.body", ec.UserDisplayName(), event) + + case RoomMemberCountCondition: + cmp, err := parseRoomMemberCountCondition(cond.Is) + if err != nil { + return false, fmt.Errorf("parsing room_member_count condition: %w", err) + } + n, err := ec.RoomMemberCount() + if err != nil { + return false, fmt.Errorf("RoomMemberCount failed: %w", err) + } + return cmp(n), nil + + case SenderNotificationPermissionCondition: + return ec.HasPowerLevel(event.Sender(), cond.Key) + + default: + return false, nil + } +} + +func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { + re, err := globToRegexp(pattern) + if err != nil { + return false, err + } + + var eventMap map[string]interface{} + if err = json.Unmarshal(event.JSON(), &eventMap); err != nil { + return false, fmt.Errorf("parsing event: %w", err) + } + v, err := lookupMapPath(strings.Split(key, "."), eventMap) + if err != nil { + // An unknown path is a benign error that shouldn't stop rule + // processing. It's just a non-match. + return false, nil + } + + return re.MatchString(fmt.Sprint(v)), nil +} diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go new file mode 100644 index 00000000..50e70336 --- /dev/null +++ b/internal/pushrules/evaluate_test.go @@ -0,0 +1,189 @@ +package pushrules + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/matrix-org/gomatrixserverlib" +) + +func TestRuleSetEvaluatorMatchEvent(t *testing.T) { + ev := mustEventFromJSON(t, `{}`) + defaultEnabled := &Rule{ + RuleID: ".default.enabled", + Default: true, + Enabled: true, + } + userEnabled := &Rule{ + RuleID: ".user.enabled", + Default: false, + Enabled: true, + } + userEnabled2 := &Rule{ + RuleID: ".user.enabled.2", + Default: false, + Enabled: true, + } + tsts := []struct { + Name string + RuleSet RuleSet + Want *Rule + }{ + {"empty", RuleSet{}, nil}, + {"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled}, + {"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled}, + {"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled}, + {"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled}, + {"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled}, + {"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled}, + {"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + rse := NewRuleSetEvaluator(nil, &tst.RuleSet) + got, err := rse.MatchEvent(ev) + if err != nil { + t.Fatalf("MatchEvent failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("MatchEvent rule: +got -want:\n%s", diff) + } + }) + } +} + +func TestRuleMatches(t *testing.T) { + emptyRule := Rule{Enabled: true} + tsts := []struct { + Name string + Kind Kind + Rule Rule + EventJSON string + Want bool + }{ + {"emptyOverride", OverrideKind, emptyRule, `{}`, true}, + {"emptyContent", ContentKind, emptyRule, `{}`, false}, + {"emptyRoom", RoomKind, emptyRule, `{}`, true}, + {"emptySender", SenderKind, emptyRule, `{}`, true}, + {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, + + {"disabled", OverrideKind, Rule{}, `{}`, false}, + + {"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true}, + {"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + + {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, + {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, + + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, + + {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, + {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, + + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) + if err != nil { + t.Fatalf("ruleMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("ruleMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +func TestConditionMatches(t *testing.T) { + tsts := []struct { + Name string + Cond Condition + EventJSON string + Want bool + }{ + {"empty", Condition{}, `{}`, false}, + {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, + + {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true}, + + {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, + {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, + + {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false}, + {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true}, + {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false}, + {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true}, + {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false}, + {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true}, + {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false}, + {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true}, + {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false}, + {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true}, + + {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true}, + {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{}) + if err != nil { + t.Fatalf("conditionMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("conditionMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +type fakeEvaluationContext struct{} + +func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" } +func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil } +func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) { + return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil +} + +func TestPatternMatches(t *testing.T) { + tsts := []struct { + Name string + Key string + Pattern string + EventJSON string + Want bool + }{ + {"empty", "", "", `{}`, false}, + + // Note that an empty pattern contains no wildcard characters, + // which implicitly means "*". + {"patternEmpty", "content", "", `{"content":{}}`, true}, + + {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true}, + {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true}, + {"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true}, + {"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true}, + {"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON)) + if err != nil { + t.Fatalf("patternMatches failed: %v", err) + } + if got != tst.Want { + t.Errorf("patternMatches: got %v, want %v", got, tst.Want) + } + }) + } +} + +func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event { + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7) + if err != nil { + t.Fatal(err) + } + return ev +} diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go new file mode 100644 index 00000000..bbed1f95 --- /dev/null +++ b/internal/pushrules/pushrules.go @@ -0,0 +1,71 @@ +package pushrules + +// An AccountRuleSets carries the rule sets associated with an +// account. +type AccountRuleSets struct { + Global RuleSet `json:"global"` // Required +} + +// A RuleSet contains all the various push rules for an +// account. Listed in decreasing order of priority. +type RuleSet struct { + Override []*Rule `json:"override,omitempty"` + Content []*Rule `json:"content,omitempty"` + Room []*Rule `json:"room,omitempty"` + Sender []*Rule `json:"sender,omitempty"` + Underride []*Rule `json:"underride,omitempty"` +} + +// A Rule contains matchers, conditions and final actions. While +// evaluating, at most one rule is considered matching. +// +// Kind and scope are part of the push rules request/responses, but +// not of the core data model. +type Rule struct { + // RuleID is either a free identifier, or the sender's MXID for + // SenderKind. Required. + RuleID string `json:"rule_id"` + + // Default indicates whether this is a server-defined default, or + // a user-provided rule. Required. + // + // The server-default rules have the lowest priority. + Default bool `json:"default"` + + // Enabled allows the user to disable rules while keeping them + // around. Required. + Enabled bool `json:"enabled"` + + // Actions describe the desired outcome, should the rule + // match. Required. + Actions []*Action `json:"actions"` + + // Conditions provide the rule's conditions for OverrideKind and + // UnderrideKind. Not allowed for other kinds. + Conditions []*Condition `json:"conditions"` + + // Pattern is the body pattern to match for ContentKind. Required + // for that kind. The interpretation is the same as that of + // Condition.Pattern. + Pattern string `json:"pattern"` +} + +// Scope only has one valid value. See also AccountRuleSets. +type Scope string + +const ( + UnknownScope Scope = "" + GlobalScope Scope = "global" +) + +// Kind is the type of push rule. See also RuleSet. +type Kind string + +const ( + UnknownKind Kind = "" + OverrideKind Kind = "override" + ContentKind Kind = "content" + RoomKind Kind = "room" + SenderKind Kind = "sender" + UnderrideKind Kind = "underride" +) diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go new file mode 100644 index 00000000..027d35ef --- /dev/null +++ b/internal/pushrules/util.go @@ -0,0 +1,125 @@ +package pushrules + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +// ActionsToTweaks converts a list of actions into a primary action +// kind and a tweaks map. Returns a nil map if it would have been +// empty. +func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) { + var kind ActionKind + tweaks := map[string]interface{}{} + + for _, a := range as { + if a.Kind == SetTweakAction { + tweaks[string(a.Tweak)] = a.Value + continue + } + if kind != UnknownAction { + return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) + } + kind = a.Kind + } + + if len(tweaks) == 0 { + tweaks = nil + } + + return kind, tweaks, nil +} + +// BoolTweakOr returns the named tweak as a boolean, and returns `def` +// on failure. +func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool { + v, ok := tweaks[string(key)] + if !ok { + return def + } + b, ok := v.(bool) + if !ok { + return def + } + return b +} + +// globToRegexp converts a Matrix glob-style pattern to a Regular expression. +func globToRegexp(pattern string) (*regexp.Regexp, error) { + // TODO: It's unclear which glob characters are supported. The only + // place this is discussed is for the unrelated "m.policy.rule.*" + // events. Assuming, the same: /[*?]/ + if !strings.ContainsAny(pattern, "*?") { + pattern = "*" + pattern + "*" + } + + // The defined syntax doesn't allow escaping the glob wildcard + // characters, which makes this a straight-forward + // replace-after-quote. + pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta) + pattern = strings.Replace(pattern, "*", ".*", -1) + pattern = strings.Replace(pattern, "?", ".", -1) + return regexp.Compile("^(" + pattern + ")$") +} + +// globNonMetaRegexp are the characters that are not considered glob +// meta-characters (i.e. may need escaping). +var globNonMetaRegexp = regexp.MustCompile("[^*?]+") + +// lookupMapPath traverses a hierarchical map structure, like the one +// produced by json.Unmarshal, to return the leaf value. Traversing +// arrays/slices is not supported, only objects/maps. +func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) { + if len(path) == 0 { + return nil, fmt.Errorf("empty path") + } + + var v interface{} = m + for i, key := range path { + m, ok := v.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v) + } + + v, ok = m[key] + if !ok { + return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], ".")) + } + } + + return v, nil +} + +// parseRoomMemberCountCondition parses a string like "2", "==2", "<2" +// into a function that checks if the argument to it fulfils the +// condition. +func parseRoomMemberCountCondition(s string) (func(int) bool, error) { + var b int + var cmp = func(a int) bool { return a == b } + switch { + case strings.HasPrefix(s, "<="): + cmp = func(a int) bool { return a <= b } + s = s[2:] + case strings.HasPrefix(s, ">="): + cmp = func(a int) bool { return a >= b } + s = s[2:] + case strings.HasPrefix(s, "<"): + cmp = func(a int) bool { return a < b } + s = s[1:] + case strings.HasPrefix(s, ">"): + cmp = func(a int) bool { return a > b } + s = s[1:] + case strings.HasPrefix(s, "=="): + // Same cmp as the default. + s = s[2:] + } + + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + b = int(v) + return cmp, nil +} diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go new file mode 100644 index 00000000..a951c55a --- /dev/null +++ b/internal/pushrules/util_test.go @@ -0,0 +1,169 @@ +package pushrules + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestActionsToTweaks(t *testing.T) { + tsts := []struct { + Name string + Input []*Action + WantKind ActionKind + WantTweaks map[string]interface{} + }{ + {"empty", nil, UnknownAction, nil}, + {"zero", []*Action{{}}, UnknownAction, nil}, + {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil}, + {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}}, + {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}}, + { + "all", + []*Action{ + {Kind: CoalesceAction}, + {Kind: SetTweakAction, Tweak: HighlightTweak}, + {Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}, + }, + CoalesceAction, + map[string]interface{}{"highlight": nil, "sound": "default"}, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + gotKind, gotTweaks, err := ActionsToTweaks(tst.Input) + if err != nil { + t.Fatalf("ActionsToTweaks failed: %v", err) + } + if gotKind != tst.WantKind { + t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind) + } + if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" { + t.Errorf("tweaks: +got -want:\n%s", diff) + } + }) + } +} + +func TestBoolTweakOr(t *testing.T) { + tsts := []struct { + Name string + Input map[string]interface{} + Def bool + Want bool + }{ + {"nil", nil, false, false}, + {"nilValue", map[string]interface{}{"highlight": nil}, true, true}, + {"false", map[string]interface{}{"highlight": false}, true, false}, + {"true", map[string]interface{}{"highlight": true}, false, true}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def) + if got != tst.Want { + t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want) + } + }) + } +} + +func TestGlobToRegexp(t *testing.T) { + tsts := []struct { + Input string + Want string + }{ + {"", "^(.*.*)$"}, + {"a", "^(.*a.*)$"}, + {"a.b", "^(.*a\\.b.*)$"}, + {"a?b", "^(a.b)$"}, + {"a*b*", "^(a.*b.*)$"}, + {"a*b?", "^(a.*b.)$"}, + } + for _, tst := range tsts { + t.Run(tst.Want, func(t *testing.T) { + got, err := globToRegexp(tst.Input) + if err != nil { + t.Fatalf("globToRegexp failed: %v", err) + } + if got.String() != tst.Want { + t.Errorf("got %v, want %v", got.String(), tst.Want) + } + }) + } +} + +func TestLookupMapPath(t *testing.T) { + tsts := []struct { + Path []string + Root map[string]interface{} + Want interface{} + }{ + {[]string{"a"}, map[string]interface{}{"a": "b"}, "b"}, + {[]string{"a"}, map[string]interface{}{"a": 42}, 42}, + {[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"}, + } + for _, tst := range tsts { + t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) { + got, err := lookupMapPath(tst.Path, tst.Root) + if err != nil { + t.Fatalf("lookupMapPath failed: %v", err) + } + if diff := cmp.Diff(tst.Want, got); diff != "" { + t.Errorf("+got -want:\n%s", diff) + } + }) + } +} + +func TestLookupMapPathInvalid(t *testing.T) { + tsts := []struct { + Path []string + Root map[string]interface{} + }{ + {nil, nil}, + {[]string{"a"}, nil}, + {[]string{"a", "b"}, map[string]interface{}{"a": "c"}}, + } + for _, tst := range tsts { + t.Run(fmt.Sprint(tst.Path), func(t *testing.T) { + got, err := lookupMapPath(tst.Path, tst.Root) + if err == nil { + t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got) + } + }) + } +} + +func TestParseRoomMemberCountCondition(t *testing.T) { + tsts := []struct { + Input string + WantTrue []int + WantFalse []int + }{ + {"1", []int{1}, []int{0, 2}}, + {"==1", []int{1}, []int{0, 2}}, + {"<1", []int{0}, []int{1, 2}}, + {"<=1", []int{0, 1}, []int{2}}, + {">1", []int{2}, []int{0, 1}}, + {">=42", []int{42, 43}, []int{41}}, + } + for _, tst := range tsts { + t.Run(tst.Input, func(t *testing.T) { + got, err := parseRoomMemberCountCondition(tst.Input) + if err != nil { + t.Fatalf("parseRoomMemberCountCondition failed: %v", err) + } + for _, v := range tst.WantTrue { + if !got(v) { + t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v) + } + } + for _, v := range tst.WantFalse { + if got(v) { + t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v) + } + } + }) + } +} diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go new file mode 100644 index 00000000..5d260f0b --- /dev/null +++ b/internal/pushrules/validate.go @@ -0,0 +1,85 @@ +package pushrules + +import ( + "fmt" + "regexp" +) + +// ValidateRule checks the rule for errors. These follow from Sytests +// and the specification. +func ValidateRule(kind Kind, rule *Rule) []error { + var errs []error + + if !validRuleIDRE.MatchString(rule.RuleID) { + errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID)) + } + + if len(rule.Actions) == 0 { + errs = append(errs, fmt.Errorf("missing actions")) + } + for _, action := range rule.Actions { + errs = append(errs, validateAction(action)...) + } + + for _, cond := range rule.Conditions { + errs = append(errs, validateCondition(cond)...) + } + + switch kind { + case OverrideKind, UnderrideKind: + // The empty list is allowed, but for JSON-encoding reasons, + // it must not be nil. + if rule.Conditions == nil { + errs = append(errs, fmt.Errorf("missing rule conditions")) + } + + case ContentKind: + if rule.Pattern == "" { + errs = append(errs, fmt.Errorf("missing content rule pattern")) + } + + case RoomKind, SenderKind: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind)) + } + + return errs +} + +// validRuleIDRE is a regexp for valid IDs. +// +// TODO: the specification doesn't seem to say what the rule ID syntax +// is. A Sytest fails if it contains a backslash. +var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`) + +// validateAction returns issues with an Action. +func validateAction(action *Action) []error { + var errs []error + + switch action.Kind { + case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind)) + } + + return errs +} + +// validateCondition returns issues with a Condition. +func validateCondition(cond *Condition) []error { + var errs []error + + switch cond.Kind { + case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition: + // Do nothing. + + default: + errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind)) + } + + return errs +} diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go new file mode 100644 index 00000000..b276eb55 --- /dev/null +++ b/internal/pushrules/validate_test.go @@ -0,0 +1,163 @@ +package pushrules + +import ( + "strings" + "testing" +) + +func TestValidateRuleNegatives(t *testing.T) { + tsts := []struct { + Name string + Kind Kind + Rule Rule + WantErrString string + }{ + {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"}, + {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"}, + {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"}, + {"noActions", OverrideKind, Rule{}, "missing actions"}, + {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"}, + {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"}, + {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"}, + {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"}, + {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := ValidateRule(tst.Kind, &tst.Rule) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateRulePositives(t *testing.T) { + tsts := []struct { + Name string + Kind Kind + Rule Rule + WantNoErrString string + }{ + {"invalidKind", OverrideKind, Rule{}, "invalid rule kind"}, + {"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"}, + {"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"}, + {"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"}, + {"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"}, + {"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"}, + {"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"}, + {"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := ValidateRule(tst.Kind, &tst.Rule) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} + +func TestValidateActionNegatives(t *testing.T) { + tsts := []struct { + Name string + Action Action + WantErrString string + }{ + {"emptyKind", Action{}, "invalid rule action kind"}, + {"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateAction(&tst.Action) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateActionPositives(t *testing.T) { + tsts := []struct { + Name string + Action Action + WantNoErrString string + }{ + {"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateAction(&tst.Action) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} + +func TestValidateConditionNegatives(t *testing.T) { + tsts := []struct { + Name string + Condition Condition + WantErrString string + }{ + {"emptyKind", Condition{}, "invalid rule condition kind"}, + {"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateCondition(&tst.Condition) + var foundErr error + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantErrString) { + foundErr = err + } + } + if foundErr == nil { + t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString) + } + }) + } +} + +func TestValidateConditionPositives(t *testing.T) { + tsts := []struct { + Name string + Condition Condition + WantNoErrString string + }{ + {"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"}, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + errs := validateCondition(&tst.Condition) + for _, err := range errs { + t.Logf("Got error %#v", err) + if strings.Contains(err.Error(), tst.WantNoErrString) { + t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString) + } + } + }) + } +} diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 8d0d2dfa..19483b26 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -163,6 +163,7 @@ type StatementList []struct { func (s StatementList) Prepare(db *sql.DB) (err error) { for _, statement := range s { if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { + err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL) return } } |