From 249b32c4f3ee2e01e6f89435e0c7a5786d2ae3a1 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 27 Sep 2022 15:01:34 +0200 Subject: Refactor notifications (#2688) This PR changes the handling of notifications - removes the `StreamEvent` and `ReadUpdate` stream - listens on the `OutputRoomEvent` stream in the UserAPI to inform the SyncAPI about unread notifications - listens on the `OutputReceiptEvent` stream in the UserAPI to set receipts/update notifications - sets the `read_markers` directly from within the internal UserAPI Co-authored-by: Neil Alexander --- setup/jetstream/nats.go | 5 +- setup/jetstream/streams.go | 10 - syncapi/consumers/clientapi.go | 46 -- syncapi/consumers/receipts.go | 48 +- syncapi/consumers/roomserver.go | 17 +- syncapi/consumers/userapi.go | 5 +- syncapi/producers/userapi_readupdate.go | 62 --- syncapi/producers/userapi_streamevent.go | 60 -- syncapi/storage/interface.go | 15 +- .../storage/postgres/notification_data_table.go | 36 +- syncapi/storage/shared/syncserver.go | 11 +- syncapi/storage/sqlite3/notification_data_table.go | 39 +- syncapi/storage/tables/interface.go | 2 +- syncapi/streams/stream_accountdata.go | 3 +- syncapi/streams/stream_notificationdata.go | 23 +- syncapi/syncapi.go | 14 +- syncapi/types/types.go | 23 +- userapi/consumers/clientapi.go | 127 +++++ userapi/consumers/roomserver.go | 614 +++++++++++++++++++++ userapi/consumers/roomserver_test.go | 129 +++++ userapi/consumers/syncapi_readupdate.go | 137 ----- userapi/consumers/syncapi_streamevent.go | 606 -------------------- userapi/consumers/syncapi_streamevent_test.go | 129 ----- userapi/internal/api.go | 46 ++ userapi/producers/syncapi.go | 7 +- userapi/storage/interface.go | 6 +- userapi/storage/postgres/notifications_table.go | 51 +- userapi/storage/postgres/pusher_table.go | 5 +- userapi/storage/shared/storage.go | 6 +- userapi/storage/sqlite3/notifications_table.go | 51 +- userapi/storage/sqlite3/pusher_table.go | 5 +- userapi/storage/storage_test.go | 11 +- userapi/storage/tables/interface.go | 6 +- userapi/userapi.go | 11 +- 34 files changed, 1068 insertions(+), 1298 deletions(-) delete mode 100644 syncapi/producers/userapi_readupdate.go delete mode 100644 syncapi/producers/userapi_streamevent.go create mode 100644 userapi/consumers/clientapi.go create mode 100644 userapi/consumers/roomserver.go create mode 100644 userapi/consumers/roomserver_test.go delete mode 100644 userapi/consumers/syncapi_readupdate.go delete mode 100644 userapi/consumers/syncapi_streamevent.go delete mode 100644 userapi/consumers/syncapi_streamevent_test.go diff --git a/setup/jetstream/nats.go b/setup/jetstream/nats.go index 3660e91e..7409fd6c 100644 --- a/setup/jetstream/nats.go +++ b/setup/jetstream/nats.go @@ -9,9 +9,10 @@ import ( "time" "github.com/getsentry/sentry-go" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" - "github.com/sirupsen/logrus" natsserver "github.com/nats-io/nats-server/v2/server" natsclient "github.com/nats-io/nats.go" @@ -184,6 +185,8 @@ func setupNATS(process *process.ProcessContext, cfg *config.JetStream, nc *natsc OutputSendToDeviceEvent: {"SyncAPIEDUServerSendToDeviceConsumer", "FederationAPIEDUServerConsumer"}, OutputTypingEvent: {"SyncAPIEDUServerTypingConsumer", "FederationAPIEDUServerConsumer"}, OutputRoomEvent: {"AppserviceRoomserverConsumer"}, + OutputStreamEvent: {"UserAPISyncAPIStreamEventConsumer"}, + OutputReadUpdate: {"UserAPISyncAPIReadUpdateConsumer"}, } { streamName := cfg.Matrix.JetStream.Prefixed(stream) for _, consumer := range consumers { diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go index c07d3a0b..ee9810da 100644 --- a/setup/jetstream/streams.go +++ b/setup/jetstream/streams.go @@ -94,16 +94,6 @@ var streams = []*nats.StreamConfig{ Retention: nats.InterestPolicy, Storage: nats.FileStorage, }, - { - Name: OutputStreamEvent, - Retention: nats.InterestPolicy, - Storage: nats.FileStorage, - }, - { - Name: OutputReadUpdate, - Retention: nats.InterestPolicy, - Storage: nats.FileStorage, - }, { Name: OutputPresenceEvent, Retention: nats.InterestPolicy, diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index f0588cab..a170a6ec 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -16,9 +16,7 @@ package consumers import ( "context" - "database/sql" "encoding/json" - "fmt" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -31,7 +29,6 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" - "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" ) @@ -46,7 +43,6 @@ type OutputClientDataConsumer struct { stream types.StreamProvider notifier *notifier.Notifier serverName gomatrixserverlib.ServerName - producer *producers.UserAPIReadProducer } // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. @@ -57,7 +53,6 @@ func NewOutputClientDataConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, - producer *producers.UserAPIReadProducer, ) *OutputClientDataConsumer { return &OutputClientDataConsumer{ ctx: process.Context(), @@ -68,7 +63,6 @@ func NewOutputClientDataConsumer( notifier: notifier, stream: stream, serverName: cfg.Matrix.ServerName, - producer: producer, } } @@ -113,15 +107,6 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M return false } - if err = s.sendReadUpdate(ctx, userID, output); err != nil { - log.WithError(err).WithFields(logrus.Fields{ - "user_id": userID, - "room_id": output.RoomID, - }).Errorf("Failed to generate read update") - sentry.CaptureException(err) - return false - } - if output.IgnoredUsers != nil { if err := s.db.UpdateIgnoresForUser(ctx, userID, output.IgnoredUsers); err != nil { log.WithError(err).WithFields(logrus.Fields{ @@ -136,34 +121,3 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msgs []*nats.M return true } - -func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error { - if output.Type != "m.fully_read" || output.ReadMarker == nil { - return nil - } - _, serverName, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if serverName != s.serverName { - return nil - } - var readPos types.StreamPosition - var fullyReadPos types.StreamPosition - if output.ReadMarker.Read != "" { - if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) - } - } - if output.ReadMarker.FullyRead != "" { - if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err) - } - } - if readPos > 0 || fullyReadPos > 0 { - if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil { - return fmt.Errorf("s.producer.SendReadUpdate: %w", err) - } - } - return nil -} diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index a18244c4..4379dd13 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -16,22 +16,19 @@ package consumers import ( "context" - "database/sql" - "fmt" "strconv" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" - "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" - log "github.com/sirupsen/logrus" ) // OutputReceiptEventConsumer consumes events that originated in the EDU server. @@ -44,7 +41,6 @@ type OutputReceiptEventConsumer struct { stream types.StreamProvider notifier *notifier.Notifier serverName gomatrixserverlib.ServerName - producer *producers.UserAPIReadProducer } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -56,7 +52,6 @@ func NewOutputReceiptEventConsumer( store storage.Database, notifier *notifier.Notifier, stream types.StreamProvider, - producer *producers.UserAPIReadProducer, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ ctx: process.Context(), @@ -67,7 +62,6 @@ func NewOutputReceiptEventConsumer( notifier: notifier, stream: stream, serverName: cfg.Matrix.ServerName, - producer: producer, } } @@ -111,42 +105,8 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return true } - if err = s.sendReadUpdate(ctx, output); err != nil { - log.WithError(err).WithFields(logrus.Fields{ - "user_id": output.UserID, - "room_id": output.RoomID, - }).Errorf("Failed to generate read update") - sentry.CaptureException(err) - return false - } - s.stream.Advance(streamPos) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) return true } - -func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output types.OutputReceiptEvent) error { - if output.Type != "m.read" { - return nil - } - _, serverName, err := gomatrixserverlib.SplitID('@', output.UserID) - if err != nil { - return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if serverName != s.serverName { - return nil - } - var readPos types.StreamPosition - if output.EventID != "" { - if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil && err != sql.ErrNoRows { - return fmt.Errorf("s.db.PositionInTopology (Read): %w", err) - } - } - if readPos > 0 { - if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil { - return fmt.Errorf("s.producer.SendReadUpdate: %w", err) - } - } - return nil -} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 6979eb48..0964ae20 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -21,17 +21,17 @@ import ( "fmt" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/syncapi/notifier" - "github.com/matrix-org/dendrite/syncapi/producers" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" ) // OutputRoomEventConsumer consumes events that originated in the room server. @@ -46,7 +46,6 @@ type OutputRoomEventConsumer struct { pduStream types.StreamProvider inviteStream types.StreamProvider notifier *notifier.Notifier - producer *producers.UserAPIStreamEventProducer } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -59,7 +58,6 @@ func NewOutputRoomEventConsumer( pduStream types.StreamProvider, inviteStream types.StreamProvider, rsAPI api.SyncRoomserverAPI, - producer *producers.UserAPIStreamEventProducer, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ ctx: process.Context(), @@ -72,7 +70,6 @@ func NewOutputRoomEventConsumer( pduStream: pduStream, inviteStream: inviteStream, rsAPI: rsAPI, - producer: producer, } } @@ -255,12 +252,6 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( return nil } - if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil { - log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID()) - sentry.CaptureException(err) - return err - } - if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) sentry.CaptureException(err) diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go index 22782352..c9b96f78 100644 --- a/syncapi/consumers/userapi.go +++ b/syncapi/consumers/userapi.go @@ -19,6 +19,9 @@ import ( "encoding/json" "github.com/getsentry/sentry-go" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" @@ -26,8 +29,6 @@ import ( "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" ) // OutputNotificationDataConsumer consumes events that originated in diff --git a/syncapi/producers/userapi_readupdate.go b/syncapi/producers/userapi_readupdate.go deleted file mode 100644 index d56cab77..00000000 --- a/syncapi/producers/userapi_readupdate.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "encoding/json" - - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" -) - -// UserAPIProducer produces events for the user API server to consume -type UserAPIReadProducer struct { - Topic string - JetStream nats.JetStreamContext -} - -// SendData sends account data to the user API server -func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error { - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.UserID, userID) - m.Header.Set(jetstream.RoomID, roomID) - - data := types.ReadUpdate{ - UserID: userID, - RoomID: roomID, - Read: readPos, - FullyRead: fullyReadPos, - } - var err error - m.Data, err = json.Marshal(data) - if err != nil { - return err - } - - log.WithFields(log.Fields{ - "user_id": userID, - "room_id": roomID, - "read_pos": readPos, - "fully_read_pos": fullyReadPos, - }).Tracef("Producing to topic '%s'", p.Topic) - - _, err = p.JetStream.PublishMsg(m) - return err -} diff --git a/syncapi/producers/userapi_streamevent.go b/syncapi/producers/userapi_streamevent.go deleted file mode 100644 index 2bbd19c0..00000000 --- a/syncapi/producers/userapi_streamevent.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package producers - -import ( - "encoding/json" - - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" -) - -// UserAPIProducer produces events for the user API server to consume -type UserAPIStreamEventProducer struct { - Topic string - JetStream nats.JetStreamContext -} - -// SendData sends account data to the user API server -func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error { - m := &nats.Msg{ - Subject: p.Topic, - Header: nats.Header{}, - } - m.Header.Set(jetstream.RoomID, roomID) - - data := types.StreamedEvent{ - Event: event, - StreamPosition: pos, - } - var err error - m.Data, err = json.Marshal(data) - if err != nil { - return err - } - - log.WithFields(log.Fields{ - "room_id": roomID, - "event_id": event.EventID(), - "event_type": event.Type(), - "stream_pos": pos, - }).Tracef("Producing to topic '%s'", p.Topic) - - _, err = p.JetStream.PublishMsg(m) - return err -} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 0c8ba4e3..ad3be420 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -29,6 +29,7 @@ import ( type Database interface { Presence SharedUsers + Notifications MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) @@ -148,12 +149,6 @@ type Database interface { // GetRoomReceipts gets all receipts for a given roomID GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]types.OutputReceiptEvent, error) - // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. - UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) - - // GetUserUnreadNotificationCounts returns statistics per room a user is interested in. - GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) - SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) @@ -179,3 +174,11 @@ type SharedUsers interface { // SharedUsers returns a subset of otherUserIDs that share a room with userID. SharedUsers(ctx context.Context, userID string, otherUserIDs []string) ([]string, error) } + +type Notifications interface { + // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key. + UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) + + // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms + GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) +} diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go index 708c3a9b..2c7b2480 100644 --- a/syncapi/storage/postgres/notification_data_table.go +++ b/syncapi/storage/postgres/notification_data_table.go @@ -18,6 +18,8 @@ import ( "context" "database/sql" + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -33,15 +35,15 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro r := ¬ificationDataStatements{} return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, - {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, + {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, {&r.selectMaxID, selectMaxNotificationIDSQL}, }.Prepare(db) } type notificationDataStatements struct { - upsertRoomUnreadCounts *sql.Stmt - selectUserUnreadCounts *sql.Stmt - selectMaxID *sql.Stmt + upsertRoomUnreadCounts *sql.Stmt + selectUserUnreadCountsForRooms *sql.Stmt + selectMaxID *sql.Stmt } const notificationDataSchema = ` @@ -61,12 +63,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_ DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4 RETURNING id` -const selectUserUnreadNotificationCountsSQL = `SELECT - id, room_id, notification_count, highlight_count - FROM syncapi_notification_data - WHERE - user_id = $1 AND - id BETWEEN $2 + 1 AND $3` +const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE user_id = $1 AND + room_id = ANY($2)` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` @@ -75,20 +75,20 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, return } -func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) +func (r *notificationDataStatements) SelectUserUnreadCountsForRooms( + ctx context.Context, txn *sql.Tx, userID string, roomIDs []string, +) (map[string]*eventutil.NotificationData, error) { + rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCountsForRooms).QueryContext(ctx, userID, pq.Array(roomIDs)) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed") roomCounts := map[string]*eventutil.NotificationData{} + var roomID string + var notificationCount, highlightCount int for rows.Next() { - var id types.StreamPosition - var roomID string - var notificationCount, highlightCount int - - if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + if err = rows.Scan(&roomID, ¬ificationCount, &highlightCount); err != nil { return nil, err } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 778ad8b1..215bad3a 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1036,8 +1036,15 @@ func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userI return } -func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - return d.NotificationData.SelectUserUnreadCounts(ctx, nil, userID, from, to) +func (d *Database) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, rooms map[string]string) (map[string]*eventutil.NotificationData, error) { + roomIDs := make([]string, 0, len(rooms)) + for roomID, membership := range rooms { + if membership != gomatrixserverlib.Join { + continue + } + roomIDs = append(roomIDs, roomID) + } + return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, nil, userID, roomIDs) } func (d *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index 66d4d438..ceff6055 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" @@ -32,19 +33,21 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t } r := ¬ificationDataStatements{ streamIDStatements: streamID, + db: db, } return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, - {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + // {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime }.Prepare(db) } type notificationDataStatements struct { + db *sql.DB streamIDStatements *StreamIDStatements upsertRoomUnreadCounts *sql.Stmt - selectUserUnreadCounts *sql.Stmt selectMaxID *sql.Stmt + //selectUserUnreadCountsForRooms *sql.Stmt } const notificationDataSchema = ` @@ -63,12 +66,10 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_ ON CONFLICT (user_id, room_id) DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7` -const selectUserUnreadNotificationCountsSQL = `SELECT - id, room_id, notification_count, highlight_count - FROM syncapi_notification_data - WHERE - user_id = $1 AND - id BETWEEN $2 + 1 AND $3` +const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_count, highlight_count + FROM syncapi_notification_data + WHERE user_id = $1 AND + room_id IN ($2)` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` @@ -81,20 +82,26 @@ func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, return } -func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) { - rows, err := sqlutil.TxStmt(txn, r.selectUserUnreadCounts).QueryContext(ctx, userID, fromExcl, toIncl) +func (r *notificationDataStatements) SelectUserUnreadCountsForRooms( + ctx context.Context, txn *sql.Tx, userID string, roomIDs []string, +) (map[string]*eventutil.NotificationData, error) { + params := make([]interface{}, len(roomIDs)+1) + params[0] = userID + for i := range roomIDs { + params[i+1] = roomIDs[i] + } + sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($1)", sqlutil.QueryVariadic(len(params)), 1) + rows, err := r.db.QueryContext(ctx, sql, params) if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed") + defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCountsForRooms: rows.close() failed") roomCounts := map[string]*eventutil.NotificationData{} + var roomID string + var notificationCount, highlightCount int for rows.Next() { - var id types.StreamPosition - var roomID string - var notificationCount, highlightCount int - - if err = rows.Scan(&id, &roomID, ¬ificationCount, &highlightCount); err != nil { + if err = rows.Scan(&roomID, ¬ificationCount, &highlightCount); err != nil { return nil, err } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 193881b4..9a873c2e 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -190,7 +190,7 @@ type Memberships interface { type NotificationData interface { UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) - SelectUserUnreadCounts(ctx context.Context, txn *sql.Tx, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) + SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) } diff --git a/syncapi/streams/stream_accountdata.go b/syncapi/streams/stream_accountdata.go index 9c19b846..0297d5c2 100644 --- a/syncapi/streams/stream_accountdata.go +++ b/syncapi/streams/stream_accountdata.go @@ -3,9 +3,10 @@ package streams import ( "context" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) type AccountDataStreamProvider struct { diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go index 8ba9e07c..33872734 100644 --- a/syncapi/streams/stream_notificationdata.go +++ b/syncapi/streams/stream_notificationdata.go @@ -30,26 +30,29 @@ func (p *NotificationDataStreamProvider) CompleteSync( func (p *NotificationDataStreamProvider) IncrementalSync( ctx context.Context, req *types.SyncRequest, - from, to types.StreamPosition, + from, _ types.StreamPosition, ) types.StreamPosition { - // We want counts for all possible rooms, so always start from zero. - countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to) + // Get the unread notifications for rooms in our join response. + // This is to ensure clients always have an unread notification section + // and can display the correct numbers. + countsByRoom, err := p.DB.GetUserUnreadNotificationCountsForRooms(ctx, req.Device.UserID, req.Rooms) if err != nil { - req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed") + req.Log.WithError(err).Error("GetUserUnreadNotificationCountsForRooms failed") return from } - // We're merely decorating existing rooms. Note that the Join map - // values are not pointers. + // We're merely decorating existing rooms. for roomID, jr := range req.Response.Rooms.Join { counts := countsByRoom[roomID] if counts == nil { continue } - - jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount - jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount + jr.UnreadNotifications = &types.UnreadNotifications{ + HighlightCount: counts.UnreadHighlightCount, + NotificationCount: counts.UnreadNotificationCount, + } req.Response.Rooms.Join[roomID] = jr } - return to + + return p.LatestPosition(ctx) } diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 68537bc4..f5d00f36 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -77,16 +77,6 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start presence consumer") } - userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{ - JetStream: js, - Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), - } - - userAPIReadUpdateProducer := &producers.UserAPIReadProducer{ - JetStream: js, - Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate), - } - keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( base.ProcessContext, cfg, cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent), js, rsAPI, syncDB, notifier, @@ -98,7 +88,7 @@ func AddPublicRoutes( roomConsumer := consumers.NewOutputRoomEventConsumer( base.ProcessContext, cfg, js, syncDB, notifier, streams.PDUStreamProvider, - streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer, + streams.InviteStreamProvider, rsAPI, ) if err = roomConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start room server consumer") @@ -106,7 +96,6 @@ func AddPublicRoutes( clientConsumer := consumers.NewOutputClientDataConsumer( base.ProcessContext, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, - userAPIReadUpdateProducer, ) if err = clientConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start client data consumer") @@ -135,7 +124,6 @@ func AddPublicRoutes( receiptConsumer := consumers.NewOutputReceiptEventConsumer( base.ProcessContext, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, - userAPIReadUpdateProducer, ) if err = receiptConsumer.Start(); err != nil { logrus.WithError(err).Panicf("failed to start receipts consumer") diff --git a/syncapi/types/types.go b/syncapi/types/types.go index d75d53ca..3b85db4a 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -398,6 +398,11 @@ func (r *Response) IsEmpty() bool { len(r.ToDevice.Events) == 0 } +type UnreadNotifications struct { + HighlightCount int `json:"highlight_count"` + NotificationCount int `json:"notification_count"` +} + // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key. type JoinResponse struct { Summary struct { @@ -419,10 +424,7 @@ type JoinResponse struct { AccountData struct { Events []gomatrixserverlib.ClientEvent `json:"events"` } `json:"account_data"` - UnreadNotifications struct { - HighlightCount int `json:"highlight_count"` - NotificationCount int `json:"notification_count"` - } `json:"unread_notifications"` + *UnreadNotifications `json:"unread_notifications,omitempty"` } // NewJoinResponse creates an empty response with initialised arrays. @@ -503,19 +505,6 @@ type Peek struct { Deleted bool } -type ReadUpdate struct { - UserID string `json:"user_id"` - RoomID string `json:"room_id"` - Read StreamPosition `json:"read,omitempty"` - FullyRead StreamPosition `json:"fully_read,omitempty"` -} - -// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. -type StreamedEvent struct { - Event *gomatrixserverlib.HeaderedEvent `json:"event"` - StreamPosition StreamPosition `json:"stream_position"` -} - // OutputReceiptEvent is an entry in the receipt output kafka log type OutputReceiptEvent struct { UserID string `json:"user_id"` diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go new file mode 100644 index 00000000..c220d35c --- /dev/null +++ b/userapi/consumers/clientapi.go @@ -0,0 +1,127 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consumers + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/storage" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/util" +) + +// OutputReceiptEventConsumer consumes events that originated in the clientAPI. +type OutputReceiptEventConsumer struct { + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + serverName gomatrixserverlib.ServerName + syncProducer *producers.SyncAPI + pgClient pushgateway.Client +} + +// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. +// Call Start() to begin consuming from the EDU server. +func NewOutputReceiptEventConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + syncProducer *producers.SyncAPI, + pgClient pushgateway.Client, +) *OutputReceiptEventConsumer { + return &OutputReceiptEventConsumer{ + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), + durable: cfg.Matrix.JetStream.Durable("UserAPIReceiptConsumer"), + db: store, + serverName: cfg.Matrix.ServerName, + syncProducer: syncProducer, + pgClient: pgClient, + } +} + +// Start consuming receipts events. +func (s *OutputReceiptEventConsumer) Start() error { + return jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), + ) +} + +func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + userID := msg.Header.Get(jetstream.UserID) + roomID := msg.Header.Get(jetstream.RoomID) + readPos := msg.Header.Get(jetstream.EventID) + evType := msg.Header.Get("type") + + if readPos == "" || evType != "m.read" { + return true + } + + log := log.WithFields(log.Fields{ + "room_id": roomID, + "user_id": userID, + }) + + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + log.WithError(err).Error("userapi clientapi consumer: SplitID failure") + return true + } + if domain != s.serverName { + return true + } + + metadata, err := msg.Metadata() + if err != nil { + return false + } + + updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) + if err != nil { + log.WithError(err).Error("userapi EDU consumer") + return false + } + + if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { + log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") + return false + } + + if !updated { + return true + } + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") + return false + } + + return true +} diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go new file mode 100644 index 00000000..952de98f --- /dev/null +++ b/userapi/consumers/roomserver.go @@ -0,0 +1,614 @@ +package consumers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/internal/pushrules" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/util" +) + +type OutputRoomEventConsumer struct { + ctx context.Context + cfg *config.UserAPI + rsAPI rsapi.UserRoomserverAPI + jetstream nats.JetStreamContext + durable string + db storage.Database + topic string + pgClient pushgateway.Client + syncProducer *producers.SyncAPI +} + +func NewOutputRoomEventConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + pgClient pushgateway.Client, + rsAPI rsapi.UserRoomserverAPI, + syncProducer *producers.SyncAPI, +) *OutputRoomEventConsumer { + return &OutputRoomEventConsumer{ + ctx: process.Context(), + cfg: cfg, + jetstream: js, + db: store, + durable: cfg.Matrix.JetStream.Durable("UserAPIRoomServerConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), + pgClient: pgClient, + rsAPI: rsAPI, + syncProducer: syncProducer, + } +} + +func (s *OutputRoomEventConsumer) Start() error { + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, 1, + s.onMessage, nats.DeliverAll(), nats.ManualAck(), + ); err != nil { + return err + } + return nil +} + +func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + var output rsapi.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") + return true + } + if output.Type != rsapi.OutputTypeNewRoomEvent { + return true + } + event := output.NewRoomEvent.Event + if event == nil { + log.Errorf("userapi consumer: expected event") + return true + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "event_type": event.Type(), + }).Tracef("Received message from roomserver: %#v", output) + + metadata, err := msg.Metadata() + if err != nil { + return true + } + + if err := s.processMessage(ctx, event, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp))); err != nil { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + }).WithError(err).Errorf("userapi consumer: process room event failure") + } + + return true +} + +func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { + members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) + if err != nil { + return fmt.Errorf("s.localRoomMembers: %w", err) + } + + if event.Type() == gomatrixserverlib.MRoomMember { + cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) + var member *localMembership + member, err = newLocalMembership(&cevent) + if err != nil { + return fmt.Errorf("newLocalMembership: %w", err) + } + if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName { + // localRoomMembers only adds joined members. An invite + // should also be pushed to the target user. + members = append(members, member) + } + } + + // TODO: run in parallel with localRoomMembers. + roomName, err := s.roomName(ctx, event) + if err != nil { + return fmt.Errorf("s.roomName: %w", err) + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "num_members": len(members), + "room_size": roomSize, + }).Tracef("Notifying members") + + // Notification.UserIsTarget is a per-member field, so we + // cannot group all users in a single request. + // + // TODO: does it have to be set? It's not required, and + // removing it means we can send all notifications to + // e.g. Element's Push gateway in one go. + for _, mem := range members { + if err := s.notifyLocal(ctx, event, mem, roomSize, roomName, streamPos); err != nil { + log.WithFields(log.Fields{ + "localpart": mem.Localpart, + }).WithError(err).Error("Unable to push to local user") + continue + } + } + + return nil +} + +type localMembership struct { + gomatrixserverlib.MemberContent + UserID string + Localpart string + Domain gomatrixserverlib.ServerName +} + +func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) { + if event.StateKey == nil { + return nil, fmt.Errorf("missing state_key") + } + + var member localMembership + if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil { + return nil, err + } + + localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey) + if err != nil { + return nil, err + } + + member.UserID = *event.StateKey + member.Localpart = localpart + member.Domain = domain + return &member, nil +} + +// localRoomMembers fetches the current local members of a room, and +// the total number of members. +func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { + req := &rsapi.QueryMembershipsForRoomRequest{ + RoomID: roomID, + JoinedOnly: true, + LocalOnly: true, + } + var res rsapi.QueryMembershipsForRoomResponse + + // XXX: This could potentially race if the state for the event is not known yet + // e.g. the event came over federation but we do not have the full state persisted. + if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil { + return nil, 0, err + } + + var members []*localMembership + var ntotal int + for _, event := range res.JoinEvents { + member, err := newLocalMembership(&event) + if err != nil { + log.WithError(err).Errorf("Parsing MemberContent") + continue + } + if member.Membership != gomatrixserverlib.Join { + continue + } + if member.Domain != s.cfg.Matrix.ServerName { + continue + } + + ntotal++ + members = append(members, member) + } + + return members, ntotal, nil +} + +// roomName returns the name in the event (if type==m.room.name), or +// looks it up in roomserver. If there is no name, +// m.room.canonical_alias is consulted. Returns an empty string if the +// room has no name. +func (s *OutputRoomEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { + if event.Type() == gomatrixserverlib.MRoomName { + name, err := unmarshalRoomName(event) + if err != nil { + return "", err + } + + if name != "" { + return name, nil + } + } + + req := &rsapi.QueryCurrentStateRequest{ + RoomID: event.RoomID(), + StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple}, + } + var res rsapi.QueryCurrentStateResponse + + if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil { + return "", nil + } + + if eventS := res.StateEvents[roomNameTuple]; eventS != nil { + return unmarshalRoomName(eventS) + } + + if event.Type() == gomatrixserverlib.MRoomCanonicalAlias { + alias, err := unmarshalCanonicalAlias(event) + if err != nil { + return "", err + } + + if alias != "" { + return alias, nil + } + } + + if event = res.StateEvents[canonicalAliasTuple]; event != nil { + return unmarshalCanonicalAlias(event) + } + + return "", nil +} + +var ( + canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias} + roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName} +) + +func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var nc eventutil.NameContent + if err := json.Unmarshal(event.Content(), &nc); err != nil { + return "", fmt.Errorf("unmarshaling NameContent: %w", err) + } + + return nc.Name, nil +} + +func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var cac eventutil.CanonicalAliasContent + if err := json.Unmarshal(event.Content(), &cac); err != nil { + return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err) + } + + return cac.Alias, nil +} + +// notifyLocal finds the right push actions for a local user, given an event. +func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { + actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) + if err != nil { + return err + } + a, tweaks, err := pushrules.ActionsToTweaks(actions) + if err != nil { + return err + } + // TODO: support coalescing. + if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + }).Tracef("Push rule evaluation rejected the event") + return nil + } + + devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) + if err != nil { + return err + } + + n := &api.Notification{ + Actions: actions, + // UNSPEC: the spec doesn't say this is a ClientEvent, but the + // fields seem to match. room_id should be missing, which + // matches the behaviour of FormatSync. + Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync), + // TODO: this is per-device, but it's not part of the primary + // key. So inserting one notification per profile tag doesn't + // make sense. What is this supposed to be? Sytests require it + // to "work", but they only use a single device. + ProfileTag: profileTag, + RoomID: event.RoomID(), + TS: gomatrixserverlib.AsTimestamp(time.Now()), + } + if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { + return err + } + + if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { + return err + } + + // We do this after InsertNotification. Thus, this should always return >=1. + userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "num_urls": len(devicesByURLAndFormat), + "num_unread": userNumUnreadNotifs, + }).Trace("Notifying single member") + + // Push gateways are out of our control, and we cannot risk + // looking up the server on a misbehaving push gateway. Each user + // receives a goroutine now that all internal API calls have been + // made. + // + // TODO: think about bounding this to one per user, and what + // ordering guarantees we must provide. + go func() { + // This background processing cannot be tied to a request. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var rejected []*pushgateway.Device + for url, fmts := range devicesByURLAndFormat { + for format, devices := range fmts { + // TODO: support "email". + if !strings.HasPrefix(url, "http") { + continue + } + + // UNSPEC: the specification suggests there can be + // more than one device per request. There is at least + // one Sytest that expects one HTTP request per + // device, rather than per URL. For now, we must + // notify each one separately. + for _, dev := range devices { + rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs)) + if err != nil { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "localpart": mem.Localpart, + }).WithError(err).Errorf("Unable to notify HTTP pusher") + continue + } + rejected = append(rejected, rej...) + } + } + } + + if len(rejected) > 0 { + s.deleteRejectedPushers(ctx, rejected, mem.Localpart) + } + }() + + return nil +} + +// evaluatePushRules fetches and evaluates the push rules of a local +// user. Returns actions (including dont_notify). +func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { + if event.Sender() == mem.UserID { + // SPEC: Homeservers MUST NOT notify the Push Gateway for + // events that the user has sent themselves. + return nil, nil + } + + // Get accountdata to check if the event.Sender() is ignored by mem.LocalPart + data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list") + if err != nil { + return nil, err + } + if data != nil { + ignored := types.IgnoredUsers{} + err = json.Unmarshal(data, &ignored) + if err != nil { + return nil, err + } + sender := event.Sender() + if _, ok := ignored.List[sender]; ok { + return nil, fmt.Errorf("user %s is ignored", sender) + } + } + ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart) + if err != nil { + return nil, err + } + + ec := &ruleSetEvalContext{ + ctx: ctx, + rsAPI: s.rsAPI, + mem: mem, + roomID: event.RoomID(), + roomSize: roomSize, + } + eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) + rule, err := eval.MatchEvent(event.Event) + if err != nil { + return nil, err + } + if rule == nil { + // SPEC: If no rules match an event, the homeserver MUST NOT + // notify the Push Gateway for that event. + return nil, err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "rule_id": rule.RuleID, + }).Trace("Matched a push rule") + + return rule.Actions, nil +} + +type ruleSetEvalContext struct { + ctx context.Context + rsAPI rsapi.UserRoomserverAPI + mem *localMembership + roomID string + roomSize int +} + +func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName } + +func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } + +func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { + req := &rsapi.QueryLatestEventsAndStateRequest{ + RoomID: rse.roomID, + StateToFetch: []gomatrixserverlib.StateKeyTuple{ + {EventType: gomatrixserverlib.MRoomPowerLevels}, + }, + } + var res rsapi.QueryLatestEventsAndStateResponse + if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil { + return false, err + } + for _, ev := range res.StateEvents { + if ev.Type() != gomatrixserverlib.MRoomPowerLevels { + continue + } + + plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event) + if err != nil { + return false, err + } + return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil + } + return true, nil +} + +// localPushDevices pushes to the configured devices of a local +// user. The map keys are [url][format]. +func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { + pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) + if err != nil { + return nil, "", err + } + + var profileTag string + devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices)) + for _, pusherDevice := range pusherDevices { + if profileTag == "" { + profileTag = pusherDevice.Pusher.ProfileTag + } + + url := pusherDevice.URL + if devicesByURL[url] == nil { + devicesByURL[url] = make(map[string][]*pushgateway.Device, 2) + } + devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device) + } + + return devicesByURL, profileTag, nil +} + +// notifyHTTP performs a notificatation to a Push Gateway. +func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { + logger := log.WithFields(log.Fields{ + "event_id": event.EventID(), + "url": url, + "localpart": localpart, + "num_devices": len(devices), + }) + + var req pushgateway.NotifyRequest + switch format { + case "event_id_only": + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Counts: &pushgateway.Counts{ + Unread: userNumUnreadNotifs, + }, + Devices: devices, + EventID: event.EventID(), + RoomID: event.RoomID(), + }, + } + + default: + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Content: event.Content(), + Counts: &pushgateway.Counts{ + Unread: userNumUnreadNotifs, + }, + Devices: devices, + EventID: event.EventID(), + ID: event.EventID(), + RoomID: event.RoomID(), + RoomName: roomName, + Sender: event.Sender(), + Type: event.Type(), + }, + } + if mem, err := event.Membership(); err == nil { + req.Notification.Membership = mem + } + if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { + req.Notification.UserIsTarget = true + } + } + + logger.Tracef("Notifying push gateway %s", url) + var res pushgateway.NotifyResponse + if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { + logger.WithError(err).Errorf("Failed to notify push gateway %s", url) + return nil, err + } + logger.WithField("num_rejected", len(res.Rejected)).Trace("Push gateway result") + + if len(res.Rejected) == 0 { + return nil, nil + } + + devMap := make(map[string]*pushgateway.Device, len(devices)) + for _, d := range devices { + devMap[d.PushKey] = d + } + rejected := make([]*pushgateway.Device, 0, len(res.Rejected)) + for _, pushKey := range res.Rejected { + d := devMap[pushKey] + if d != nil { + rejected = append(rejected, d) + } + } + + return rejected, nil +} + +// deleteRejectedPushers deletes the pushers associated with the given devices. +func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": devices[0].AppID, + "num_devices": len(devices), + }).Warnf("Deleting pushers rejected by the HTTP push gateway") + + for _, d := range devices { + if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { + log.WithFields(log.Fields{ + "localpart": localpart, + }).WithError(err).Errorf("Unable to delete rejected pusher") + } + } +} diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go new file mode 100644 index 00000000..3bbeb439 --- /dev/null +++ b/userapi/consumers/roomserver_test.go @@ -0,0 +1,129 @@ +package consumers + +import ( + "context" + "testing" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/internal/pushrules" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/storage" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "", 4, 0, 0, "") + if err != nil { + t.Fatalf("failed to create new user db: %v", err) + } + return db, close +} + +func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent { + t.Helper() + ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) + if err != nil { + t.Fatalf("failed to create event: %v", err) + } + return ev.Headered(gomatrixserverlib.RoomVersionV10) +} + +func Test_evaluatePushRules(t *testing.T) { + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateDatabase(t, dbType) + defer close() + consumer := OutputRoomEventConsumer{db: db} + + testCases := []struct { + name string + eventContent string + wantAction pushrules.ActionKind + wantActions []*pushrules.Action + wantNotify bool + }{ + { + name: "m.receipt doesn't notify", + eventContent: `{"type":"m.receipt"}`, + wantAction: pushrules.UnknownAction, + wantActions: nil, + }, + { + name: "m.reaction doesn't notify", + eventContent: `{"type":"m.reaction"}`, + wantAction: pushrules.DontNotifyAction, + wantActions: []*pushrules.Action{ + { + Kind: pushrules.DontNotifyAction, + }, + }, + }, + { + name: "m.room.message notifies", + eventContent: `{"type":"m.room.message"}`, + wantNotify: true, + wantAction: pushrules.NotifyAction, + wantActions: []*pushrules.Action{ + {Kind: pushrules.NotifyAction}, + { + Kind: pushrules.SetTweakAction, + Tweak: pushrules.HighlightTweak, + Value: false, + }, + }, + }, + { + name: "m.room.message highlights", + eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`, + wantNotify: true, + wantAction: pushrules.NotifyAction, + wantActions: []*pushrules.Action{ + {Kind: pushrules.NotifyAction}, + { + Kind: pushrules.SetTweakAction, + Tweak: pushrules.SoundTweak, + Value: "default", + }, + { + Kind: pushrules.SetTweakAction, + Tweak: pushrules.HighlightTweak, + Value: true, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actions, err := consumer.evaluatePushRules(ctx, mustCreateEvent(t, tc.eventContent), &localMembership{ + UserID: "@test:localhost", + Localpart: "test", + Domain: "localhost", + }, 10) + if err != nil { + t.Fatalf("failed to evaluate push rules: %v", err) + } + assert.Equal(t, tc.wantActions, actions) + gotAction, _, err := pushrules.ActionsToTweaks(actions) + if err != nil { + t.Fatalf("failed to get actions: %v", err) + } + if gotAction != tc.wantAction { + t.Fatalf("expected action to be '%s', got '%s'", tc.wantAction, gotAction) + } + // this is taken from `notifyLocal` + if tc.wantNotify && gotAction != pushrules.NotifyAction && gotAction != pushrules.CoalesceAction { + t.Fatalf("expected to notify but didn't") + } + }) + + } + }) +} diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go deleted file mode 100644 index 54654f75..00000000 --- a/userapi/consumers/syncapi_readupdate.go +++ /dev/null @@ -1,137 +0,0 @@ -package consumers - -import ( - "context" - "encoding/json" - - "github.com/matrix-org/dendrite/internal/pushgateway" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/syncapi/types" - uapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/producers" - "github.com/matrix-org/dendrite/userapi/storage" - "github.com/matrix-org/dendrite/userapi/util" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" -) - -type OutputReadUpdateConsumer struct { - ctx context.Context - cfg *config.UserAPI - jetstream nats.JetStreamContext - durable string - db storage.Database - pgClient pushgateway.Client - ServerName gomatrixserverlib.ServerName - topic string - userAPI uapi.UserInternalAPI - syncProducer *producers.SyncAPI -} - -func NewOutputReadUpdateConsumer( - process *process.ProcessContext, - cfg *config.UserAPI, - js nats.JetStreamContext, - store storage.Database, - pgClient pushgateway.Client, - userAPI uapi.UserInternalAPI, - syncProducer *producers.SyncAPI, -) *OutputReadUpdateConsumer { - return &OutputReadUpdateConsumer{ - ctx: process.Context(), - cfg: cfg, - jetstream: js, - db: store, - ServerName: cfg.Matrix.ServerName, - durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReadUpdate), - pgClient: pgClient, - userAPI: userAPI, - syncProducer: syncProducer, - } -} - -func (s *OutputReadUpdateConsumer) Start() error { - if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, 1, - s.onMessage, nats.DeliverAll(), nats.ManualAck(), - ); err != nil { - return err - } - return nil -} - -func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { - msg := msgs[0] // Guaranteed to exist if onMessage is called - var read types.ReadUpdate - if err := json.Unmarshal(msg.Data, &read); err != nil { - log.WithError(err).Error("userapi clientapi consumer: message parse failure") - return true - } - if read.FullyRead == 0 && read.Read == 0 { - return true - } - - userID := string(msg.Header.Get(jetstream.UserID)) - roomID := string(msg.Header.Get(jetstream.RoomID)) - - localpart, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - log.WithError(err).Error("userapi clientapi consumer: SplitID failure") - return true - } - if domain != s.ServerName { - log.Error("userapi clientapi consumer: not a local user") - return true - } - - log := log.WithFields(log.Fields{ - "room_id": roomID, - "user_id": userID, - }) - log.Tracef("Received read update from sync API: %#v", read) - - if read.Read > 0 { - updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true) - if err != nil { - log.WithError(err).Error("userapi EDU consumer") - return false - } - - if updated { - if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { - log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") - return false - } - if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { - log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") - return false - } - } - } - - if read.FullyRead > 0 { - deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead)) - if err != nil { - log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed") - return false - } - - if deleted { - if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { - log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed") - return false - } - - if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil { - log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed") - return false - } - } - } - - return true -} diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go deleted file mode 100644 index f3b2bf27..00000000 --- a/userapi/consumers/syncapi_streamevent.go +++ /dev/null @@ -1,606 +0,0 @@ -package consumers - -import ( - "context" - "encoding/json" - "fmt" - "strings" - "time" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - log "github.com/sirupsen/logrus" - - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/internal/pushgateway" - "github.com/matrix-org/dendrite/internal/pushrules" - rsapi "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/setup/process" - "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/producers" - "github.com/matrix-org/dendrite/userapi/storage" - "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/dendrite/userapi/util" -) - -type OutputStreamEventConsumer struct { - ctx context.Context - cfg *config.UserAPI - rsAPI rsapi.UserRoomserverAPI - jetstream nats.JetStreamContext - durable string - db storage.Database - topic string - pgClient pushgateway.Client - syncProducer *producers.SyncAPI -} - -func NewOutputStreamEventConsumer( - process *process.ProcessContext, - cfg *config.UserAPI, - js nats.JetStreamContext, - store storage.Database, - pgClient pushgateway.Client, - rsAPI rsapi.UserRoomserverAPI, - syncProducer *producers.SyncAPI, -) *OutputStreamEventConsumer { - return &OutputStreamEventConsumer{ - ctx: process.Context(), - cfg: cfg, - jetstream: js, - db: store, - durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputStreamEvent), - pgClient: pgClient, - rsAPI: rsAPI, - syncProducer: syncProducer, - } -} - -func (s *OutputStreamEventConsumer) Start() error { - if err := jetstream.JetStreamConsumer( - s.ctx, s.jetstream, s.topic, s.durable, 1, - s.onMessage, nats.DeliverAll(), nats.ManualAck(), - ); err != nil { - return err - } - return nil -} - -func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool { - msg := msgs[0] // Guaranteed to exist if onMessage is called - var output types.StreamedEvent - output.Event = &gomatrixserverlib.HeaderedEvent{} - if err := json.Unmarshal(msg.Data, &output); err != nil { - log.WithError(err).Errorf("userapi consumer: message parse failure") - return true - } - if output.Event.Event == nil { - log.Errorf("userapi consumer: expected event") - return true - } - - log.WithFields(log.Fields{ - "event_id": output.Event.EventID(), - "event_type": output.Event.Type(), - "stream_pos": output.StreamPosition, - }).Tracef("Received message from sync API: %#v", output) - - if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { - log.WithFields(log.Fields{ - "event_id": output.Event.EventID(), - }).WithError(err).Errorf("userapi consumer: process room event failure") - } - - return true -} - -func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { - members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) - if err != nil { - return fmt.Errorf("s.localRoomMembers: %w", err) - } - - if event.Type() == gomatrixserverlib.MRoomMember { - cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) - var member *localMembership - member, err = newLocalMembership(&cevent) - if err != nil { - return fmt.Errorf("newLocalMembership: %w", err) - } - if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName { - // localRoomMembers only adds joined members. An invite - // should also be pushed to the target user. - members = append(members, member) - } - } - - // TODO: run in parallel with localRoomMembers. - roomName, err := s.roomName(ctx, event) - if err != nil { - return fmt.Errorf("s.roomName: %w", err) - } - - log.WithFields(log.Fields{ - "event_id": event.EventID(), - "room_id": event.RoomID(), - "num_members": len(members), - "room_size": roomSize, - }).Tracef("Notifying members") - - // Notification.UserIsTarget is a per-member field, so we - // cannot group all users in a single request. - // - // TODO: does it have to be set? It's not required, and - // removing it means we can send all notifications to - // e.g. Element's Push gateway in one go. - for _, mem := range members { - if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { - log.WithFields(log.Fields{ - "localpart": mem.Localpart, - }).WithError(err).Debugf("Unable to push to local user") - continue - } - } - - return nil -} - -type localMembership struct { - gomatrixserverlib.MemberContent - UserID string - Localpart string - Domain gomatrixserverlib.ServerName -} - -func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) { - if event.StateKey == nil { - return nil, fmt.Errorf("missing state_key") - } - - var member localMembership - if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil { - return nil, err - } - - localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey) - if err != nil { - return nil, err - } - - member.UserID = *event.StateKey - member.Localpart = localpart - member.Domain = domain - return &member, nil -} - -// localRoomMembers fetches the current local members of a room, and -// the total number of members. -func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { - req := &rsapi.QueryMembershipsForRoomRequest{ - RoomID: roomID, - JoinedOnly: true, - LocalOnly: true, - } - var res rsapi.QueryMembershipsForRoomResponse - - // XXX: This could potentially race if the state for the event is not known yet - // e.g. the event came over federation but we do not have the full state persisted. - if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil { - return nil, 0, err - } - - var members []*localMembership - var ntotal int - for _, event := range res.JoinEvents { - member, err := newLocalMembership(&event) - if err != nil { - log.WithError(err).Errorf("Parsing MemberContent") - continue - } - if member.Membership != gomatrixserverlib.Join { - continue - } - if member.Domain != s.cfg.Matrix.ServerName { - continue - } - - ntotal++ - members = append(members, member) - } - - return members, ntotal, nil -} - -// roomName returns the name in the event (if type==m.room.name), or -// looks it up in roomserver. If there is no name, -// m.room.canonical_alias is consulted. Returns an empty string if the -// room has no name. -func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { - if event.Type() == gomatrixserverlib.MRoomName { - name, err := unmarshalRoomName(event) - if err != nil { - return "", err - } - - if name != "" { - return name, nil - } - } - - req := &rsapi.QueryCurrentStateRequest{ - RoomID: event.RoomID(), - StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple}, - } - var res rsapi.QueryCurrentStateResponse - - if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil { - return "", nil - } - - if eventS := res.StateEvents[roomNameTuple]; eventS != nil { - return unmarshalRoomName(eventS) - } - - if event.Type() == gomatrixserverlib.MRoomCanonicalAlias { - alias, err := unmarshalCanonicalAlias(event) - if err != nil { - return "", err - } - - if alias != "" { - return alias, nil - } - } - - if event = res.StateEvents[canonicalAliasTuple]; event != nil { - return unmarshalCanonicalAlias(event) - } - - return "", nil -} - -var ( - canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias} - roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName} -) - -func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) { - var nc eventutil.NameContent - if err := json.Unmarshal(event.Content(), &nc); err != nil { - return "", fmt.Errorf("unmarshaling NameContent: %w", err) - } - - return nc.Name, nil -} - -func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) { - var cac eventutil.CanonicalAliasContent - if err := json.Unmarshal(event.Content(), &cac); err != nil { - return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err) - } - - return cac.Alias, nil -} - -// notifyLocal finds the right push actions for a local user, given an event. -func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { - actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) - if err != nil { - return err - } - a, tweaks, err := pushrules.ActionsToTweaks(actions) - if err != nil { - return err - } - // TODO: support coalescing. - if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { - log.WithFields(log.Fields{ - "event_id": event.EventID(), - "room_id": event.RoomID(), - "localpart": mem.Localpart, - }).Debugf("Push rule evaluation rejected the event") - return nil - } - - devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) - if err != nil { - return err - } - - n := &api.Notification{ - Actions: actions, - // UNSPEC: the spec doesn't say this is a ClientEvent, but the - // fields seem to match. room_id should be missing, which - // matches the behaviour of FormatSync. - Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync), - // TODO: this is per-device, but it's not part of the primary - // key. So inserting one notification per profile tag doesn't - // make sense. What is this supposed to be? Sytests require it - // to "work", but they only use a single device. - ProfileTag: profileTag, - RoomID: event.RoomID(), - TS: gomatrixserverlib.AsTimestamp(time.Now()), - } - if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { - return err - } - - if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { - return err - } - - // We do this after InsertNotification. Thus, this should always return >=1. - userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) - if err != nil { - return err - } - - log.WithFields(log.Fields{ - "event_id": event.EventID(), - "room_id": event.RoomID(), - "localpart": mem.Localpart, - "num_urls": len(devicesByURLAndFormat), - "num_unread": userNumUnreadNotifs, - }).Debugf("Notifying single member") - - // Push gateways are out of our control, and we cannot risk - // looking up the server on a misbehaving push gateway. Each user - // receives a goroutine now that all internal API calls have been - // made. - // - // TODO: think about bounding this to one per user, and what - // ordering guarantees we must provide. - go func() { - // This background processing cannot be tied to a request. - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - var rejected []*pushgateway.Device - for url, fmts := range devicesByURLAndFormat { - for format, devices := range fmts { - // TODO: support "email". - if !strings.HasPrefix(url, "http") { - continue - } - - // UNSPEC: the specification suggests there can be - // more than one device per request. There is at least - // one Sytest that expects one HTTP request per - // device, rather than per URL. For now, we must - // notify each one separately. - for _, dev := range devices { - rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs)) - if err != nil { - log.WithFields(log.Fields{ - "event_id": event.EventID(), - "localpart": mem.Localpart, - }).WithError(err).Errorf("Unable to notify HTTP pusher") - continue - } - rejected = append(rejected, rej...) - } - } - } - - if len(rejected) > 0 { - s.deleteRejectedPushers(ctx, rejected, mem.Localpart) - } - }() - - return nil -} - -// evaluatePushRules fetches and evaluates the push rules of a local -// user. Returns actions (including dont_notify). -func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { - if event.Sender() == mem.UserID { - // SPEC: Homeservers MUST NOT notify the Push Gateway for - // events that the user has sent themselves. - return nil, nil - } - - // Get accountdata to check if the event.Sender() is ignored by mem.LocalPart - data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list") - if err != nil { - return nil, err - } - if data != nil { - ignored := types.IgnoredUsers{} - err = json.Unmarshal(data, &ignored) - if err != nil { - return nil, err - } - sender := event.Sender() - if _, ok := ignored.List[sender]; ok { - return nil, fmt.Errorf("user %s is ignored", sender) - } - } - ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart) - if err != nil { - return nil, err - } - - ec := &ruleSetEvalContext{ - ctx: ctx, - rsAPI: s.rsAPI, - mem: mem, - roomID: event.RoomID(), - roomSize: roomSize, - } - eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.Event) - if err != nil { - return nil, err - } - if rule == nil { - // SPEC: If no rules match an event, the homeserver MUST NOT - // notify the Push Gateway for that event. - return nil, err - } - - log.WithFields(log.Fields{ - "event_id": event.EventID(), - "room_id": event.RoomID(), - "localpart": mem.Localpart, - "rule_id": rule.RuleID, - }).Tracef("Matched a push rule") - - return rule.Actions, nil -} - -type ruleSetEvalContext struct { - ctx context.Context - rsAPI rsapi.UserRoomserverAPI - mem *localMembership - roomID string - roomSize int -} - -func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName } - -func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } - -func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { - req := &rsapi.QueryLatestEventsAndStateRequest{ - RoomID: rse.roomID, - StateToFetch: []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomPowerLevels}, - }, - } - var res rsapi.QueryLatestEventsAndStateResponse - if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil { - return false, err - } - for _, ev := range res.StateEvents { - if ev.Type() != gomatrixserverlib.MRoomPowerLevels { - continue - } - - plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event) - if err != nil { - return false, err - } - return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil - } - return true, nil -} - -// localPushDevices pushes to the configured devices of a local -// user. The map keys are [url][format]. -func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { - pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) - if err != nil { - return nil, "", err - } - - var profileTag string - devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices)) - for _, pusherDevice := range pusherDevices { - if profileTag == "" { - profileTag = pusherDevice.Pusher.ProfileTag - } - - url := pusherDevice.URL - if devicesByURL[url] == nil { - devicesByURL[url] = make(map[string][]*pushgateway.Device, 2) - } - devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device) - } - - return devicesByURL, profileTag, nil -} - -// notifyHTTP performs a notificatation to a Push Gateway. -func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { - logger := log.WithFields(log.Fields{ - "event_id": event.EventID(), - "url": url, - "localpart": localpart, - "num_devices": len(devices), - }) - - var req pushgateway.NotifyRequest - switch format { - case "event_id_only": - req = pushgateway.NotifyRequest{ - Notification: pushgateway.Notification{ - Counts: &pushgateway.Counts{ - Unread: userNumUnreadNotifs, - }, - Devices: devices, - EventID: event.EventID(), - RoomID: event.RoomID(), - }, - } - - default: - req = pushgateway.NotifyRequest{ - Notification: pushgateway.Notification{ - Content: event.Content(), - Counts: &pushgateway.Counts{ - Unread: userNumUnreadNotifs, - }, - Devices: devices, - EventID: event.EventID(), - ID: event.EventID(), - RoomID: event.RoomID(), - RoomName: roomName, - Sender: event.Sender(), - Type: event.Type(), - }, - } - if mem, err := event.Membership(); err == nil { - req.Notification.Membership = mem - } - if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { - req.Notification.UserIsTarget = true - } - } - - logger.Debugf("Notifying push gateway %s", url) - var res pushgateway.NotifyResponse - if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { - logger.WithError(err).Errorf("Failed to notify push gateway %s", url) - return nil, err - } - logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") - - if len(res.Rejected) == 0 { - return nil, nil - } - - devMap := make(map[string]*pushgateway.Device, len(devices)) - for _, d := range devices { - devMap[d.PushKey] = d - } - rejected := make([]*pushgateway.Device, 0, len(res.Rejected)) - for _, pushKey := range res.Rejected { - d := devMap[pushKey] - if d != nil { - rejected = append(rejected, d) - } - } - - return rejected, nil -} - -// deleteRejectedPushers deletes the pushers associated with the given devices. -func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { - log.WithFields(log.Fields{ - "localpart": localpart, - "app_id0": devices[0].AppID, - "num_devices": len(devices), - }).Warnf("Deleting pushers rejected by the HTTP push gateway") - - for _, d := range devices { - if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { - log.WithFields(log.Fields{ - "localpart": localpart, - }).WithError(err).Errorf("Unable to delete rejected pusher") - } - } -} diff --git a/userapi/consumers/syncapi_streamevent_test.go b/userapi/consumers/syncapi_streamevent_test.go deleted file mode 100644 index 48ea0fe1..00000000 --- a/userapi/consumers/syncapi_streamevent_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package consumers - -import ( - "context" - "testing" - - "github.com/matrix-org/gomatrixserverlib" - "github.com/stretchr/testify/assert" - - "github.com/matrix-org/dendrite/internal/pushrules" - "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/test" - "github.com/matrix-org/dendrite/userapi/storage" -) - -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { - t.Helper() - connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(nil, &config.DatabaseOptions{ - ConnectionString: config.DataSource(connStr), - }, "", 4, 0, 0, "") - if err != nil { - t.Fatalf("failed to create new user db: %v", err) - } - return db, close -} - -func mustCreateEvent(t *testing.T, content string) *gomatrixserverlib.HeaderedEvent { - t.Helper() - ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(content), false, gomatrixserverlib.RoomVersionV10) - if err != nil { - t.Fatalf("failed to create event: %v", err) - } - return ev.Headered(gomatrixserverlib.RoomVersionV10) -} - -func Test_evaluatePushRules(t *testing.T) { - ctx := context.Background() - - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) - defer close() - consumer := OutputStreamEventConsumer{db: db} - - testCases := []struct { - name string - eventContent string - wantAction pushrules.ActionKind - wantActions []*pushrules.Action - wantNotify bool - }{ - { - name: "m.receipt doesn't notify", - eventContent: `{"type":"m.receipt"}`, - wantAction: pushrules.UnknownAction, - wantActions: nil, - }, - { - name: "m.reaction doesn't notify", - eventContent: `{"type":"m.reaction"}`, - wantAction: pushrules.DontNotifyAction, - wantActions: []*pushrules.Action{ - { - Kind: pushrules.DontNotifyAction, - }, - }, - }, - { - name: "m.room.message notifies", - eventContent: `{"type":"m.room.message"}`, - wantNotify: true, - wantAction: pushrules.NotifyAction, - wantActions: []*pushrules.Action{ - {Kind: pushrules.NotifyAction}, - { - Kind: pushrules.SetTweakAction, - Tweak: pushrules.HighlightTweak, - Value: false, - }, - }, - }, - { - name: "m.room.message highlights", - eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`, - wantNotify: true, - wantAction: pushrules.NotifyAction, - wantActions: []*pushrules.Action{ - {Kind: pushrules.NotifyAction}, - { - Kind: pushrules.SetTweakAction, - Tweak: pushrules.SoundTweak, - Value: "default", - }, - { - Kind: pushrules.SetTweakAction, - Tweak: pushrules.HighlightTweak, - Value: true, - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - actions, err := consumer.evaluatePushRules(ctx, mustCreateEvent(t, tc.eventContent), &localMembership{ - UserID: "@test:localhost", - Localpart: "test", - Domain: "localhost", - }, 10) - if err != nil { - t.Fatalf("failed to evaluate push rules: %v", err) - } - assert.Equal(t, tc.wantActions, actions) - gotAction, _, err := pushrules.ActionsToTweaks(actions) - if err != nil { - t.Fatalf("failed to get actions: %v", err) - } - if gotAction != tc.wantAction { - t.Fatalf("expected action to be '%s', got '%s'", tc.wantAction, gotAction) - } - // this is taken from `notifyLocal` - if tc.wantNotify && gotAction != pushrules.NotifyAction && gotAction != pushrules.CoalesceAction { - t.Fatalf("expected to notify but didn't") - } - }) - - } - }) -} diff --git a/userapi/internal/api.go b/userapi/internal/api.go index dcbb7361..3e761a88 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -39,6 +40,7 @@ import ( "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" + userapiUtil "github.com/matrix-org/dendrite/userapi/util" ) type UserInternalAPI struct { @@ -51,6 +53,7 @@ type UserInternalAPI struct { AppServices []config.ApplicationService KeyAPI keyapi.UserKeyAPI RSAPI rsapi.UserRoomserverAPI + PgClient pushgateway.Client } func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { @@ -73,6 +76,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc ignoredUsers = &synctypes.IgnoredUsers{} _ = json.Unmarshal(req.AccountData, ignoredUsers) } + if req.DataType == "m.fully_read" { + if err := a.setFullyRead(ctx, req); err != nil { + return err + } + } if err := a.SyncProducer.SendAccountData(req.UserID, eventutil.AccountData{ RoomID: req.RoomID, Type: req.DataType, @@ -84,6 +92,44 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc return nil } +func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccountDataRequest) error { + var output eventutil.ReadMarkerJSON + + if err := json.Unmarshal(req.AccountData, &output); err != nil { + return err + } + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") + return nil + } + if domain != a.ServerName { + return nil + } + + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) + if err != nil { + logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") + return err + } + + if err = a.SyncProducer.GetAndSendNotificationData(ctx, req.UserID, req.RoomID); err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: GetAndSendNotificationData failed") + return err + } + + // nothing changed, no need to notify the push gateway + if !deleted { + return nil + } + + if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil { + logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") + return err + } + return nil +} + func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index 27cfc284..f556ea35 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -4,12 +4,13 @@ import ( "context" "encoding/json" - "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/storage" ) type JetStreamPublisher interface { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index fbac463e..02efe7af 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -119,9 +119,9 @@ type ThreePID interface { } type Notification interface { - InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error) + InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index a27c1125..24a30b2f 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -20,12 +20,13 @@ import ( "encoding/json" "time" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" ) type notificationsStatements struct { @@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) if err != nil { return false, err @@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) if err != nil { return false, err @@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) - - if err != nil { - return 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - - return count, nil - } - return 0, rows.Err() +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) + return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { - rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) - - if err != nil { - return 0, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var total, highlight int64 - if err := rows.Scan(&total, &highlight); err != nil { - return 0, 0, err - } - - return total, highlight, nil - } - return 0, 0, rows.Err() +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) + return } diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 2eb379ae..6fb714fb 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -19,11 +19,12 @@ import ( "database/sql" "encoding/json" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/sirupsen/logrus" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers( pushers = append(pushers, pusher) } - logrus.Debugf("Database returned %d pushers", len(pushers)) + logrus.Tracef("Database returned %d pushers", len(pushers)) return pushers, rows.Err() } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index e32a442d..3ff299f1 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -700,13 +700,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } -func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) }) } -func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) return err @@ -714,7 +714,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomI return } -func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) return err diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index df826025..a35ec7be 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -20,12 +20,13 @@ import ( "encoding/json" "time" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" - log "github.com/sirupsen/logrus" ) type notificationsStatements struct { @@ -110,7 +111,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -126,7 +127,7 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) if err != nil { return false, err @@ -140,7 +141,7 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) if err != nil { return false, err @@ -196,40 +197,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) - - if err != nil { - return 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var count int64 - if err := rows.Scan(&count); err != nil { - return 0, err - } - - return count, nil - } - return 0, rows.Err() +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) + return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { - rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) - - if err != nil { - return 0, 0, err - } - defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") - - if rows.Next() { - var total, highlight int64 - if err := rows.Scan(&total, &highlight); err != nil { - return 0, 0, err - } - - return total, highlight, nil - } - return 0, 0, rows.Err() +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) + return } diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index dba97c3d..4de0a9f0 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -19,11 +19,12 @@ import ( "database/sql" "encoding/json" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/sirupsen/logrus" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -136,7 +137,7 @@ func (s *pushersStatements) SelectPushers( pushers = append(pushers, pusher) } - logrus.Debugf("Database returned %d pushers", len(pushers)) + logrus.Tracef("Database returned %d pushers", len(pushers)) return pushers, rows.Err() } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index a2609733..ca7c1bfd 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -7,6 +7,11 @@ import ( "testing" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" @@ -14,10 +19,6 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/bcrypt" ) const loginTokenLifetime = time.Minute @@ -513,7 +514,7 @@ func Test_Notification(t *testing.T) { RoomID: roomID, TS: gomatrixserverlib.AsTimestamp(ts), } - err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification) + err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") } diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 2fe95567..cc428799 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -105,9 +105,9 @@ type PusherTable interface { type NotificationTable interface { Clean(ctx context.Context, txn *sql.Tx) error - Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error - DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) - UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) + Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) diff --git a/userapi/userapi.go b/userapi/userapi.go index 23855a89..d26b4e19 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -81,16 +81,17 @@ func NewInternalAPI( KeyAPI: keyAPI, RSAPI: rsAPI, DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, + PgClient: pgClient, } - readConsumer := consumers.NewOutputReadUpdateConsumer( - base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, + receiptConsumer := consumers.NewOutputReceiptEventConsumer( + base.ProcessContext, cfg, js, db, syncProducer, pgClient, ) - if err := readConsumer.Start(); err != nil { - logrus.WithError(err).Panic("failed to start user API read update consumer") + if err := receiptConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API receipt consumer") } - eventConsumer := consumers.NewOutputStreamEventConsumer( + eventConsumer := consumers.NewOutputRoomEventConsumer( base.ProcessContext, cfg, js, db, pgClient, rsAPI, syncProducer, ) if err := eventConsumer.Start(); err != nil { -- cgit v1.2.3