diff options
author | devonh <devon.dmytro@gmail.com> | 2023-06-14 14:23:46 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-14 14:23:46 +0000 |
commit | e4665979bfbe006368d55189f074e456fe19b198 (patch) | |
tree | e909d694a022478d0dbe3cc58ee8a2dc289bc969 /internal | |
parent | 7a2e325d1014d76188b47a011730a42443f3c174 (diff) |
Merge SenderID & Per Room User Key work (#3109)
Diffstat (limited to 'internal')
-rw-r--r-- | internal/pushrules/evaluate.go | 6 | ||||
-rw-r--r-- | internal/pushrules/evaluate_test.go | 8 | ||||
-rw-r--r-- | internal/transactionrequest.go | 2 | ||||
-rw-r--r-- | internal/transactionrequest_test.go | 4 |
4 files changed, 12 insertions, 8 deletions
diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index ac760895..28dea97c 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -115,7 +115,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati case SenderKind: userID := "" - sender, err := userIDForSender(event.RoomID(), event.SenderID()) + validRoomID, err := spec.NewRoomID(event.RoomID()) + if err != nil { + return false, err + } + sender, err := userIDForSender(*validRoomID, event.SenderID()) if err == nil { userID = sender.String() } diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index 859d1f8a..a4ccc3d0 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -8,7 +8,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" ) -func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -73,7 +73,7 @@ func TestRuleMatches(t *testing.T) { {"emptyOverride", OverrideKind, emptyRule, `{}`, true}, {"emptyContent", ContentKind, emptyRule, `{}`, false}, {"emptyRoom", RoomKind, emptyRule, `{}`, true}, - {"emptySender", SenderKind, emptyRule, `{}`, true}, + {"emptySender", SenderKind, emptyRule, `{"room_id":"!room:example.com"}`, true}, {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true}, {"disabled", OverrideKind, Rule{}, `{}`, false}, @@ -90,8 +90,8 @@ func TestRuleMatches(t *testing.T) { {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, - {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true}, - {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false}, + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com","room_id":"!room:example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index b2929bb5..5bf7d819 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -167,7 +167,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index 1d32c806..ffc1cd89 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -70,7 +70,7 @@ type FakeRsAPI struct { bannedFromRoom bool } -func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } @@ -642,7 +642,7 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } -func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return spec.NewUserID(string(senderID), true) } |