aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-06-06 20:55:18 +0000
committerGitHub <noreply@github.com>2023-06-06 20:55:18 +0000
commit7a1fd7f512ce06a472a2051ee63eae4a270eb71a (patch)
tree20128b0d3f7c69dd776aa7b2b9bc3194dda7dd75 /userapi
parent725ff5567d2a3bc9992b065e72ccabefb595ec1c (diff)
PDU Sender split (#3100)
Initial cut of splitting PDU Sender into SenderID & looking up UserID where required.
Diffstat (limited to 'userapi')
-rw-r--r--userapi/consumers/roomserver.go43
-rw-r--r--userapi/consumers/roomserver_test.go11
-rw-r--r--userapi/util/notify_test.go9
3 files changed, 48 insertions, 15 deletions
diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go
index 3cfdc0ce..c025deee 100644
--- a/userapi/consumers/roomserver.go
+++ b/userapi/consumers/roomserver.go
@@ -108,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
}
if s.cfg.Matrix.ReportStats.Enabled {
- go s.storeMessageStats(ctx, event.Type(), event.Sender(), event.RoomID())
+ go s.storeMessageStats(ctx, event.Type(), event.SenderID(), event.RoomID())
}
log.WithFields(log.Fields{
@@ -301,7 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch {
case event.Type() == spec.MRoomMember:
- cevent := synctypes.ToClientEvent(event, synctypes.FormatAll)
+ sender := spec.UserID{}
+ userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if queryErr == nil && userID != nil {
+ sender = *userID
+ }
+ cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
var member *localMembership
member, err = newLocalMembership(&cevent)
if err != nil {
@@ -529,12 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
return fmt.Errorf("s.localPushDevices: %w", err)
}
+ sender := spec.UserID{}
+ userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
n := &api.Notification{
Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync.
- Event: synctypes.ToClientEvent(event, synctypes.FormatSync),
+ Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender),
// TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it
@@ -615,7 +625,12 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
// evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify).
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
- if event.Sender() == mem.UserID {
+ user := ""
+ sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err == nil {
+ user = sender.String()
+ }
+ if user == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves.
return nil, nil
@@ -632,9 +647,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
if err != nil {
return nil, err
}
- sender := event.Sender()
- if _, ok := ignored.List[sender]; ok {
- return nil, fmt.Errorf("user %s is ignored", sender)
+ if _, ok := ignored.List[sender.String()]; ok {
+ return nil, fmt.Errorf("user %s is ignored", sender.String())
}
}
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
@@ -650,7 +664,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
roomSize: roomSize,
}
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
- rule, err := eval.MatchEvent(event.PDU)
+ rule, err := eval.MatchEvent(event.PDU, func(roomID, senderID string) (*spec.UserID, error) {
+ return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ })
if err != nil {
return nil, err
}
@@ -682,7 +698,7 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display
func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
-func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) {
+func (rse *ruleSetEvalContext) HasPowerLevel(senderID, levelKey string) (bool, error) {
req := &rsapi.QueryLatestEventsAndStateRequest{
RoomID: rse.roomID,
StateToFetch: []gomatrixserverlib.StateKeyTuple{
@@ -702,7 +718,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
if err != nil {
return false, err
}
- return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil
+ return plc.UserLevel(senderID) >= plc.NotificationLevel(levelKey), nil
}
return true, nil
}
@@ -756,6 +772,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
}
default:
+ sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err != nil {
+ logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID())
+ return nil, err
+ }
req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{
Content: event.Content(),
@@ -767,7 +788,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
ID: event.EventID(),
RoomID: event.RoomID(),
RoomName: roomName,
- Sender: event.Sender(),
+ Sender: sender.String(),
Type: event.Type(),
},
}
diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go
index 53977206..899a5aaf 100644
--- a/userapi/consumers/roomserver_test.go
+++ b/userapi/consumers/roomserver_test.go
@@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/pushrules"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/storage"
@@ -44,13 +45,19 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
return &types.HeaderedEvent{PDU: ev}
}
+type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
+
+func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return spec.NewUserID(senderID, true)
+}
+
func Test_evaluatePushRules(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
- consumer := OutputRoomEventConsumer{db: db}
+ consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}}
testCases := []struct {
name string
@@ -86,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) {
},
{
name: "m.room.message highlights",
- eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`,
+ eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`,
wantNotify: true,
wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{
diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go
index e1c88d47..27dd373c 100644
--- a/userapi/util/notify_test.go
+++ b/userapi/util/notify_test.go
@@ -11,6 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"golang.org/x/crypto/bcrypt"
@@ -87,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
}
// Prepare pusher with our test server URL
- if err := db.UpsertPusher(ctx, api.Pusher{
+ if err = db.UpsertPusher(ctx, api.Pusher{
Kind: api.HTTPKind,
AppID: appID,
PushKey: pushKey,
@@ -99,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) {
}
// Insert a dummy event
+ sender, err := spec.NewUserID(alice.ID, true)
+ if err != nil {
+ t.Error(err)
+ }
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
- Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll),
+ Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender),
}); err != nil {
t.Error(err)
}