diff options
author | Dan <dan@globekeeper.com> | 2022-03-03 13:40:53 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-03 11:40:53 +0000 |
commit | f05ce478f05dcaf650fbae68a39aaf5d9880a580 (patch) | |
tree | a6a47f77bba03ec7a05a8d98bea6791d47f3b48a /userapi | |
parent | 111f01ddc81d775dfdaab6e6a3a6afa6fa5608ea (diff) |
Implement Push Notifications (#1842)
* Add Pushserver component with Pushers API
Co-authored-by: Tommie Gannert <tommie@gannert.se>
Co-authored-by: Dan Peleg <dan@globekeeper.com>
* Wire Pushserver component
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
* Add PushGatewayClient.
The full event format is required for Sytest.
* Add a pushrules module.
* Change user API account creation to use the new pushrules module's defaults.
Introduces "scope" as required by client API, and some small field
tweaks to make some 61push Sytests pass.
* Add push rules query/put API in Pushserver.
This manipulates account data over User API, and fires sync messages
for changes. Those sync messages should, according to an existing TODO
in clientapi, be moved to userapi.
Forks clientapi/producers/syncapi.go to pushserver/ for later extension.
* Add clientapi routes for push rules to Pushserver.
A cleanup would be to move more of the name-splitting logic into
pushrules.go, to depollute routing.go.
* Output rooms.join.unread_notifications in /sync.
This is the read-side. Pushserver will be the write-side.
* Implement pushserver/storage for notifications.
* Use PushGatewayClient and the pushrules module in Pushserver's room consumer.
* Use one goroutine per user to avoid locking up the entire server for
one bad push gateway.
* Split pushing by format.
* Send one device per push. Sytest does not support coalescing
multiple devices into one push. Matches Synapse. Either we change
Sytest, or remove the group-by-url-and-format logic.
* Write OutputNotificationData from push server. Sync API is already
the consumer.
* Implement read receipt consumers in Pushserver.
Supports m.read and m.fully_read receipts.
* Add clientapi route for /unstable/notifications.
* Rename to UpsertPusher for clarity and handle pusher update
* Fix linter errors
* Ignore body.Close() error check
* Fix push server internal http wiring
* Add 40 newly passing 61push tests to whitelist
* Add next 12 newly passing 61push tests to whitelist
* Send notification data before notifying users in EDU server consumer
* NATS JetStream
* Goodbye sarama
* Fix `NewStreamTokenFromString`
* Consume on the correct topic for the roomserver
* Don't panic, NAK instead
* Move push notifications into the User API
* Don't set null values since that apparently causes Element upsetti
* Also set omitempty on conditions
* Fix bug so that we don't override the push rules unnecessarily
* Tweak defaults
* Update defaults
* More tweaks
* Move `/notifications` onto `r0`/`v3` mux
* User API will consume events and read/fully read markers from the sync API with stream positions, instead of consuming directly
Co-authored-by: Piotr Kozimor <p1996k@gmail.com>
Co-authored-by: Tommie Gannert <tommie@gannert.se>
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'userapi')
-rw-r--r-- | userapi/api/api.go | 83 | ||||
-rw-r--r-- | userapi/api/api_trace.go | 30 | ||||
-rw-r--r-- | userapi/consumers/syncapi_readupdate.go | 136 | ||||
-rw-r--r-- | userapi/consumers/syncapi_streamevent.go | 588 | ||||
-rw-r--r-- | userapi/internal/api.go | 171 | ||||
-rw-r--r-- | userapi/inthttp/client.go | 61 | ||||
-rw-r--r-- | userapi/inthttp/server.go | 82 | ||||
-rw-r--r-- | userapi/producers/syncapi.go | 104 | ||||
-rw-r--r-- | userapi/storage/interface.go | 13 | ||||
-rw-r--r-- | userapi/storage/postgres/notifications_table.go | 219 | ||||
-rw-r--r-- | userapi/storage/postgres/pusher_table.go | 157 | ||||
-rw-r--r-- | userapi/storage/postgres/storage.go | 10 | ||||
-rw-r--r-- | userapi/storage/shared/storage.go | 109 | ||||
-rw-r--r-- | userapi/storage/sqlite3/notifications_table.go | 219 | ||||
-rw-r--r-- | userapi/storage/sqlite3/pusher_table.go | 157 | ||||
-rw-r--r-- | userapi/storage/sqlite3/storage.go | 10 | ||||
-rw-r--r-- | userapi/storage/tables/interface.go | 40 | ||||
-rw-r--r-- | userapi/userapi.go | 57 | ||||
-rw-r--r-- | userapi/userapi_test.go | 6 | ||||
-rw-r--r-- | userapi/util/devices.go | 100 | ||||
-rw-r--r-- | userapi/util/notify.go | 76 |
21 files changed, 2402 insertions, 26 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go index 2be662e5..e9cdbe01 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" ) // UserInternalAPI is the internal API for information about users and devices. @@ -28,6 +29,7 @@ type UserInternalAPI interface { LoginTokenInternalAPI InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error + PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error @@ -37,6 +39,10 @@ type UserInternalAPI interface { PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error + PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error + PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error + PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error + QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error @@ -45,6 +51,9 @@ type UserInternalAPI interface { QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error + QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error + QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error + QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error } type PerformKeyBackupRequest struct { @@ -424,3 +433,77 @@ const ( // AccountTypeAppService indicates this is an appservice account AccountTypeAppService AccountType = 4 ) + +type QueryPushersRequest struct { + Localpart string +} + +type QueryPushersResponse struct { + Pushers []Pusher `json:"pushers"` +} + +type PerformPusherSetRequest struct { + Pusher // Anonymous field because that's how clientapi unmarshals it. + Localpart string + Append bool `json:"append"` +} + +type PerformPusherDeletionRequest struct { + Localpart string + SessionID int64 +} + +// Pusher represents a push notification subscriber +type Pusher struct { + SessionID int64 `json:"session_id,omitempty"` + PushKey string `json:"pushkey"` + PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"` + Kind PusherKind `json:"kind"` + AppID string `json:"app_id"` + AppDisplayName string `json:"app_display_name"` + DeviceDisplayName string `json:"device_display_name"` + ProfileTag string `json:"profile_tag"` + Language string `json:"lang"` + Data map[string]interface{} `json:"data"` +} + +type PusherKind string + +const ( + EmailKind PusherKind = "email" + HTTPKind PusherKind = "http" +) + +type PerformPushRulesPutRequest struct { + UserID string `json:"user_id"` + RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` +} + +type QueryPushRulesRequest struct { + UserID string `json:"user_id"` +} + +type QueryPushRulesResponse struct { + RuleSets *pushrules.AccountRuleSets `json:"rule_sets"` +} + +type QueryNotificationsRequest struct { + Localpart string `json:"localpart"` // Required. + From string `json:"from,omitempty"` + Limit int `json:"limit,omitempty"` + Only string `json:"only,omitempty"` +} + +type QueryNotificationsResponse struct { + NextToken string `json:"next_token"` + Notifications []*Notification `json:"notifications"` // Required. +} + +type Notification struct { + Actions []*pushrules.Action `json:"actions"` // Required. + Event gomatrixserverlib.ClientEvent `json:"event"` // Required. + ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional. + Read bool `json:"read"` // Required. + RoomID string `json:"room_id"` // Required. + TS gomatrixserverlib.Timestamp `json:"ts"` // Required. +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index aa069f40..9334f445 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -79,6 +79,21 @@ func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *Perfor util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res)) return err } +func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error { + err := t.Impl.PerformPusherSet(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error { + err := t.Impl.PerformPusherDeletion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error { + err := t.Impl.PerformPushRulesPut(ctx, req, res) + util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res)) + return err +} func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) { t.Impl.QueryKeyBackup(ctx, req, res) util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) @@ -118,6 +133,21 @@ func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryO util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res)) return err } +func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error { + err := t.Impl.QueryPushers(ctx, req, res) + util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error { + err := t.Impl.QueryPushRules(ctx, req, res) + util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res)) + return err +} +func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error { + err := t.Impl.QueryNotifications(ctx, req, res) + util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res)) + return err +} func js(thing interface{}) string { b, err := json.Marshal(thing) diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go new file mode 100644 index 00000000..2e58020b --- /dev/null +++ b/userapi/consumers/syncapi_readupdate.go @@ -0,0 +1,136 @@ +package consumers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/util" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type OutputReadUpdateConsumer struct { + ctx context.Context + cfg *config.UserAPI + jetstream nats.JetStreamContext + durable string + db storage.Database + pgClient pushgateway.Client + ServerName gomatrixserverlib.ServerName + topic string + userAPI uapi.UserInternalAPI + syncProducer *producers.SyncAPI +} + +func NewOutputReadUpdateConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + pgClient pushgateway.Client, + userAPI uapi.UserInternalAPI, + syncProducer *producers.SyncAPI, +) *OutputReadUpdateConsumer { + return &OutputReadUpdateConsumer{ + ctx: process.Context(), + cfg: cfg, + jetstream: js, + db: store, + ServerName: cfg.Matrix.ServerName, + durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate), + pgClient: pgClient, + userAPI: userAPI, + syncProducer: syncProducer, + } +} + +func (s *OutputReadUpdateConsumer) Start() error { + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { + return err + } + return nil +} + +func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var read types.ReadUpdate + if err := json.Unmarshal(msg.Data, &read); err != nil { + log.WithError(err).Error("userapi clientapi consumer: message parse failure") + return true + } + if read.FullyRead == 0 && read.Read == 0 { + return true + } + + userID := string(msg.Header.Get(jetstream.UserID)) + roomID := string(msg.Header.Get(jetstream.RoomID)) + + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + log.WithError(err).Error("userapi clientapi consumer: SplitID failure") + return true + } + if domain != s.ServerName { + log.Error("userapi clientapi consumer: not a local user") + return true + } + + log := log.WithFields(log.Fields{ + "room_id": roomID, + "user_id": userID, + }) + log.Tracef("Received read update from sync API: %#v", read) + + if read.Read > 0 { + updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true) + if err != nil { + log.WithError(err).Error("userapi EDU consumer") + return false + } + + if updated { + if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil { + log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed") + return false + } + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") + return false + } + } + } + + if read.FullyRead > 0 { + deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead)) + if err != nil { + log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed") + return false + } + + if deleted { + if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed") + return false + } + + if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil { + log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed") + return false + } + } + } + + return true +} diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go new file mode 100644 index 00000000..11081327 --- /dev/null +++ b/userapi/consumers/syncapi_streamevent.go @@ -0,0 +1,588 @@ +package consumers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/internal/pushrules" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/util" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type OutputStreamEventConsumer struct { + ctx context.Context + cfg *config.UserAPI + userAPI api.UserInternalAPI + rsAPI rsapi.RoomserverInternalAPI + jetstream nats.JetStreamContext + durable string + db storage.Database + topic string + pgClient pushgateway.Client + syncProducer *producers.SyncAPI +} + +func NewOutputStreamEventConsumer( + process *process.ProcessContext, + cfg *config.UserAPI, + js nats.JetStreamContext, + store storage.Database, + pgClient pushgateway.Client, + userAPI api.UserInternalAPI, + rsAPI rsapi.RoomserverInternalAPI, + syncProducer *producers.SyncAPI, +) *OutputStreamEventConsumer { + return &OutputStreamEventConsumer{ + ctx: process.Context(), + cfg: cfg, + jetstream: js, + db: store, + durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"), + topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent), + pgClient: pgClient, + userAPI: userAPI, + rsAPI: rsAPI, + syncProducer: syncProducer, + } +} + +func (s *OutputStreamEventConsumer) Start() error { + if err := jetstream.JetStreamConsumer( + s.ctx, s.jetstream, s.topic, s.durable, s.onMessage, + nats.DeliverAll(), nats.ManualAck(), + ); err != nil { + return err + } + return nil +} + +func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool { + var output types.StreamedEvent + output.Event = &gomatrixserverlib.HeaderedEvent{} + if err := json.Unmarshal(msg.Data, &output); err != nil { + log.WithError(err).Errorf("userapi consumer: message parse failure") + return true + } + if output.Event.Event == nil { + log.Errorf("userapi consumer: expected event") + return true + } + + log.WithFields(log.Fields{ + "event_id": output.Event.EventID(), + "event_type": output.Event.Type(), + "stream_pos": output.StreamPosition, + }).Tracef("Received message from sync API: %#v", output) + + if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil { + log.WithFields(log.Fields{ + "event_id": output.Event.EventID(), + }).WithError(err).Errorf("userapi consumer: process room event failure") + } + + return true +} + +func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error { + members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) + if err != nil { + return fmt.Errorf("s.localRoomMembers: %w", err) + } + + if event.Type() == gomatrixserverlib.MRoomMember { + cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) + var member *localMembership + member, err = newLocalMembership(&cevent) + if err != nil { + return fmt.Errorf("newLocalMembership: %w", err) + } + if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName { + // localRoomMembers only adds joined members. An invite + // should also be pushed to the target user. + members = append(members, member) + } + } + + // TODO: run in parallel with localRoomMembers. + roomName, err := s.roomName(ctx, event) + if err != nil { + return fmt.Errorf("s.roomName: %w", err) + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "num_members": len(members), + "room_size": roomSize, + }).Tracef("Notifying members") + + // Notification.UserIsTarget is a per-member field, so we + // cannot group all users in a single request. + // + // TODO: does it have to be set? It's not required, and + // removing it means we can send all notifications to + // e.g. Element's Push gateway in one go. + for _, mem := range members { + if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil { + log.WithFields(log.Fields{ + "localpart": mem.Localpart, + }).WithError(err).Debugf("Unable to push to local user") + continue + } + } + + return nil +} + +type localMembership struct { + gomatrixserverlib.MemberContent + UserID string + Localpart string + Domain gomatrixserverlib.ServerName +} + +func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) { + if event.StateKey == nil { + return nil, fmt.Errorf("missing state_key") + } + + var member localMembership + if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil { + return nil, err + } + + localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey) + if err != nil { + return nil, err + } + + member.UserID = *event.StateKey + member.Localpart = localpart + member.Domain = domain + return &member, nil +} + +// localRoomMembers fetches the current local members of a room, and +// the total number of members. +func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { + req := &rsapi.QueryMembershipsForRoomRequest{ + RoomID: roomID, + JoinedOnly: true, + } + var res rsapi.QueryMembershipsForRoomResponse + + // XXX: This could potentially race if the state for the event is not known yet + // e.g. the event came over federation but we do not have the full state persisted. + if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil { + return nil, 0, err + } + + var members []*localMembership + var ntotal int + for _, event := range res.JoinEvents { + member, err := newLocalMembership(&event) + if err != nil { + log.WithError(err).Errorf("Parsing MemberContent") + continue + } + if member.Membership != gomatrixserverlib.Join { + continue + } + if member.Domain != s.cfg.Matrix.ServerName { + continue + } + + ntotal++ + members = append(members, member) + } + + return members, ntotal, nil +} + +// roomName returns the name in the event (if type==m.room.name), or +// looks it up in roomserver. If there is no name, +// m.room.canonical_alias is consulted. Returns an empty string if the +// room has no name. +func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) { + if event.Type() == gomatrixserverlib.MRoomName { + name, err := unmarshalRoomName(event) + if err != nil { + return "", err + } + + if name != "" { + return name, nil + } + } + + req := &rsapi.QueryCurrentStateRequest{ + RoomID: event.RoomID(), + StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple}, + } + var res rsapi.QueryCurrentStateResponse + + if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil { + return "", nil + } + + if eventS := res.StateEvents[roomNameTuple]; eventS != nil { + return unmarshalRoomName(eventS) + } + + if event.Type() == gomatrixserverlib.MRoomCanonicalAlias { + alias, err := unmarshalCanonicalAlias(event) + if err != nil { + return "", err + } + + if alias != "" { + return alias, nil + } + } + + if event = res.StateEvents[canonicalAliasTuple]; event != nil { + return unmarshalCanonicalAlias(event) + } + + return "", nil +} + +var ( + canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias} + roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName} +) + +func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var nc eventutil.NameContent + if err := json.Unmarshal(event.Content(), &nc); err != nil { + return "", fmt.Errorf("unmarshaling NameContent: %w", err) + } + + return nc.Name, nil +} + +func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) { + var cac eventutil.CanonicalAliasContent + if err := json.Unmarshal(event.Content(), &cac); err != nil { + return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err) + } + + return cac.Alias, nil +} + +// notifyLocal finds the right push actions for a local user, given an event. +func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error { + actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) + if err != nil { + return err + } + a, tweaks, err := pushrules.ActionsToTweaks(actions) + if err != nil { + return err + } + // TODO: support coalescing. + if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + }).Tracef("Push rule evaluation rejected the event") + return nil + } + + devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) + if err != nil { + return err + } + + n := &api.Notification{ + Actions: actions, + // UNSPEC: the spec doesn't say this is a ClientEvent, but the + // fields seem to match. room_id should be missing, which + // matches the behaviour of FormatSync. + Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync), + // TODO: this is per-device, but it's not part of the primary + // key. So inserting one notification per profile tag doesn't + // make sense. What is this supposed to be? Sytests require it + // to "work", but they only use a single device. + ProfileTag: profileTag, + RoomID: event.RoomID(), + TS: gomatrixserverlib.AsTimestamp(time.Now()), + } + if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil { + return err + } + + if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { + return err + } + + // We do this after InsertNotification. Thus, this should always return >=1. + userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "num_urls": len(devicesByURLAndFormat), + "num_unread": userNumUnreadNotifs, + }).Tracef("Notifying single member") + + // Push gateways are out of our control, and we cannot risk + // looking up the server on a misbehaving push gateway. Each user + // receives a goroutine now that all internal API calls have been + // made. + // + // TODO: think about bounding this to one per user, and what + // ordering guarantees we must provide. + go func() { + // This background processing cannot be tied to a request. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var rejected []*pushgateway.Device + for url, fmts := range devicesByURLAndFormat { + for format, devices := range fmts { + // TODO: support "email". + if !strings.HasPrefix(url, "http") { + continue + } + + // UNSPEC: the specification suggests there can be + // more than one device per request. There is at least + // one Sytest that expects one HTTP request per + // device, rather than per URL. For now, we must + // notify each one separately. + for _, dev := range devices { + rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs)) + if err != nil { + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "localpart": mem.Localpart, + }).WithError(err).Errorf("Unable to notify HTTP pusher") + continue + } + rejected = append(rejected, rej...) + } + } + } + + if len(rejected) > 0 { + s.deleteRejectedPushers(ctx, rejected, mem.Localpart) + } + }() + + return nil +} + +// evaluatePushRules fetches and evaluates the push rules of a local +// user. Returns actions (including dont_notify). +func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { + if event.Sender() == mem.UserID { + // SPEC: Homeservers MUST NOT notify the Push Gateway for + // events that the user has sent themselves. + return nil, nil + } + + var res api.QueryPushRulesResponse + if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil { + return nil, err + } + + ec := &ruleSetEvalContext{ + ctx: ctx, + rsAPI: s.rsAPI, + mem: mem, + roomID: event.RoomID(), + roomSize: roomSize, + } + eval := pushrules.NewRuleSetEvaluator(ec, &res.RuleSets.Global) + rule, err := eval.MatchEvent(event.Event) + if err != nil { + return nil, err + } + if rule == nil { + // SPEC: If no rules match an event, the homeserver MUST NOT + // notify the Push Gateway for that event. + return nil, err + } + + log.WithFields(log.Fields{ + "event_id": event.EventID(), + "room_id": event.RoomID(), + "localpart": mem.Localpart, + "rule_id": rule.RuleID, + }).Tracef("Matched a push rule") + + return rule.Actions, nil +} + +type ruleSetEvalContext struct { + ctx context.Context + rsAPI rsapi.RoomserverInternalAPI + mem *localMembership + roomID string + roomSize int +} + +func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName } + +func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } + +func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { + req := &rsapi.QueryLatestEventsAndStateRequest{ + RoomID: rse.roomID, + StateToFetch: []gomatrixserverlib.StateKeyTuple{ + {EventType: gomatrixserverlib.MRoomPowerLevels}, + }, + } + var res rsapi.QueryLatestEventsAndStateResponse + if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil { + return false, err + } + for _, ev := range res.StateEvents { + if ev.Type() != gomatrixserverlib.MRoomPowerLevels { + continue + } + + plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event) + if err != nil { + return false, err + } + return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil + } + return true, nil +} + +// localPushDevices pushes to the configured devices of a local +// user. The map keys are [url][format]. +func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { + pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) + if err != nil { + return nil, "", err + } + + var profileTag string + devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices)) + for _, pusherDevice := range pusherDevices { + if profileTag == "" { + profileTag = pusherDevice.Pusher.ProfileTag + } + + url := pusherDevice.URL + if devicesByURL[url] == nil { + devicesByURL[url] = make(map[string][]*pushgateway.Device, 2) + } + devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device) + } + + return devicesByURL, profileTag, nil +} + +// notifyHTTP performs a notificatation to a Push Gateway. +func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) { + logger := log.WithFields(log.Fields{ + "event_id": event.EventID(), + "url": url, + "localpart": localpart, + "num_devices": len(devices), + }) + + var req pushgateway.NotifyRequest + switch format { + case "event_id_only": + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Counts: &pushgateway.Counts{}, + Devices: devices, + EventID: event.EventID(), + RoomID: event.RoomID(), + }, + } + + default: + req = pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Content: event.Content(), + Counts: &pushgateway.Counts{ + Unread: userNumUnreadNotifs, + }, + Devices: devices, + EventID: event.EventID(), + ID: event.EventID(), + RoomID: event.RoomID(), + RoomName: roomName, + Sender: event.Sender(), + Type: event.Type(), + }, + } + if mem, err := event.Membership(); err == nil { + req.Notification.Membership = mem + } + if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) { + req.Notification.UserIsTarget = true + } + } + + logger.Debugf("Notifying push gateway %s", url) + var res pushgateway.NotifyResponse + if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil { + logger.WithError(err).Errorf("Failed to notify push gateway %s", url) + return nil, err + } + logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result") + + if len(res.Rejected) == 0 { + return nil, nil + } + + devMap := make(map[string]*pushgateway.Device, len(devices)) + for _, d := range devices { + devMap[d.PushKey] = d + } + rejected := make([]*pushgateway.Device, 0, len(res.Rejected)) + for _, pushKey := range res.Rejected { + d := devMap[pushKey] + if d != nil { + rejected = append(rejected, d) + } + } + + return rejected, nil +} + +// deleteRejectedPushers deletes the pushers associated with the given devices. +func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": devices[0].AppID, + "num_devices": len(devices), + }).Warnf("Deleting pushers rejected by the HTTP push gateway") + + for _, d := range devices { + if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { + log.WithFields(log.Fields{ + "localpart": localpart, + }).WithError(err).Errorf("Unable to delete rejected pusher") + } + } +} diff --git a/userapi/internal/api.go b/userapi/internal/api.go index f54cc613..7a42fc60 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -20,6 +20,8 @@ import ( "encoding/json" "errors" "fmt" + "strconv" + "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -27,16 +29,22 @@ import ( "github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) type UserInternalAPI struct { - DB storage.Database - ServerName gomatrixserverlib.ServerName + DB storage.Database + SyncProducer *producers.SyncAPI + + DisableTLSValidation bool + ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService KeyAPI keyapi.KeyInternalAPI @@ -595,3 +603,162 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB } res.Keys = result } + +func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { + if req.Limit == 0 || req.Limit > 1000 { + req.Limit = 1000 + } + + var fromID int64 + var err error + if req.From != "" { + fromID, err = strconv.ParseInt(req.From, 10, 64) + if err != nil { + return fmt.Errorf("QueryNotifications: parsing 'from': %w", err) + } + } + var filter tables.NotificationFilter = tables.AllNotifications + if req.Only == "highlight" { + filter = tables.HighlightNotifications + } + notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter) + if err != nil { + return err + } + if notifs == nil { + // This ensures empty is JSON-encoded as [] instead of null. + notifs = []*api.Notification{} + } + res.Notifications = notifs + if lastID >= 0 { + res.NextToken = strconv.FormatInt(lastID+1, 10) + } + return nil +} + +func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "localpart": req.Localpart, + "pushkey": req.Pusher.PushKey, + "display_name": req.Pusher.AppDisplayName, + }).Info("PerformPusherCreation") + if !req.Append { + err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey) + if err != nil { + return err + } + } + if req.Pusher.Kind == "" { + return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) + } + if req.Pusher.PushKeyTS == 0 { + req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now()) + } + return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) +} + +func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { + pushers, err := a.DB.GetPushers(ctx, req.Localpart) + if err != nil { + return err + } + for i := range pushers { + logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) + if pushers[i].SessionID != req.SessionID { + err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart) + if err != nil { + return err + } + } + } + return nil +} + +func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { + var err error + res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart) + return err +} + +func (a *UserInternalAPI) PerformPushRulesPut( + ctx context.Context, + req *api.PerformPushRulesPutRequest, + _ *struct{}, +) error { + bs, err := json.Marshal(&req.RuleSets) + if err != nil { + return err + } + userReq := api.InputAccountDataRequest{ + UserID: req.UserID, + DataType: pushRulesAccountDataType, + AccountData: json.RawMessage(bs), + } + var userRes api.InputAccountDataResponse // empty + if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil { + return err + } + + if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil { + util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed") + } + + return nil +} + +func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { + userReq := api.QueryAccountDataRequest{ + UserID: req.UserID, + DataType: pushRulesAccountDataType, + } + var userRes api.QueryAccountDataResponse + if err := a.QueryAccountData(ctx, &userReq, &userRes); err != nil { + return err + } + bs, ok := userRes.GlobalAccountData[pushRulesAccountDataType] + if ok { + // Legacy Dendrite users will have completely empty push rules, so we should + // detect that situation and set some defaults. + var rules struct { + G struct { + Content []json.RawMessage `json:"content"` + Override []json.RawMessage `json:"override"` + Room []json.RawMessage `json:"room"` + Sender []json.RawMessage `json:"sender"` + Underride []json.RawMessage `json:"underride"` + } `json:"global"` + } + if err := json.Unmarshal([]byte(bs), &rules); err == nil { + count := len(rules.G.Content) + len(rules.G.Override) + + len(rules.G.Room) + len(rules.G.Sender) + len(rules.G.Underride) + ok = count > 0 + } + } + if !ok { + // If we didn't find any default push rules then we should just generate some + // fresh ones. + localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) + } + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, a.ServerName) + prbs, err := json.Marshal(pushRuleSets) + if err != nil { + return fmt.Errorf("failed to marshal default push rules: %w", err) + } + if err := a.DB.SaveAccountData(ctx, localpart, "", pushRulesAccountDataType, json.RawMessage(prbs)); err != nil { + return fmt.Errorf("failed to save default push rules: %w", err) + } + res.RuleSets = pushRuleSets + return nil + } + var data pushrules.AccountRuleSets + if err := json.Unmarshal([]byte(bs), &data); err != nil { + util.GetLogger(ctx).WithError(err).Error("json.Unmarshal of push rules failed") + return err + } + res.RuleSets = &data + return nil +} + +const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 1599d463..8ec649ad 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -37,6 +37,9 @@ const ( PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" PerformKeyBackupPath = "/userapi/performKeyBackup" + PerformPusherSetPath = "/pushserver/performPusherSet" + PerformPusherDeletionPath = "/pushserver/performPusherDeletion" + PerformPushRulesPutPath = "/pushserver/performPushRulesPut" QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryProfilePath = "/userapi/queryProfile" @@ -46,6 +49,9 @@ const ( QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QuerySearchProfilesPath = "/userapi/querySearchProfiles" QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" + QueryPushersPath = "/pushserver/queryPushers" + QueryPushRulesPath = "/pushserver/queryPushRules" + QueryNotificationsPath = "/pushserver/queryNotifications" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -249,3 +255,58 @@ func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.Query res.Error = err.Error() } } + +func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications") + defer span.Finish() + + return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res) +} + +func (h *httpUserInternalAPI) PerformPusherSet( + ctx context.Context, + request *api.PerformPusherSetRequest, + response *struct{}, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet") + defer span.Finish() + + apiURL := h.apiURL + PerformPusherSetPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformPusherDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers") + defer span.Finish() + + apiURL := h.apiURL + QueryPushersPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + +func (h *httpUserInternalAPI) PerformPushRulesPut( + ctx context.Context, + request *api.PerformPushRulesPutRequest, + response *struct{}, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut") + defer span.Finish() + + apiURL := h.apiURL + PerformPushRulesPutPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules") + defer span.Finish() + + apiURL := h.apiURL + QueryPushRulesPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index d00ee042..526f9957 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -265,4 +265,86 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryNotificationsPath, + httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse { + var request api.QueryNotificationsRequest + var response api.QueryNotificationsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryNotifications(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(PerformPusherSetPath, + httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse { + request := api.PerformPusherSetRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformPusherDeletionPath, + httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformPusherDeletionRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(QueryPushersPath, + httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse { + request := api.QueryPushersRequest{} + response := api.QueryPushersResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryPushers(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(PerformPushRulesPutPath, + httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse { + request := api.PerformPushRulesPutRequest{} + response := struct{}{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + + internalAPIMux.Handle(QueryPushRulesPath, + httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse { + request := api.QueryPushRulesRequest{} + response := api.QueryPushRulesResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryPushRules(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go new file mode 100644 index 00000000..4a206f33 --- /dev/null +++ b/userapi/producers/syncapi.go @@ -0,0 +1,104 @@ +package producers + +import ( + "context" + "encoding/json" + + "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + log "github.com/sirupsen/logrus" +) + +type JetStreamPublisher interface { + PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) +} + +// SyncAPI produces messages for the Sync API server to consume. +type SyncAPI struct { + db storage.Database + producer JetStreamPublisher + clientDataTopic string + notificationDataTopic string +} + +func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI { + return &SyncAPI{ + db: db, + producer: js, + clientDataTopic: clientDataTopic, + notificationDataTopic: notificationDataTopic, + } +} + +// SendAccountData sends account data to the Sync API server. +func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error { + m := &nats.Msg{ + Subject: p.clientDataTopic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + + var err error + m.Data, err = json.Marshal(eventutil.AccountData{ + RoomID: roomID, + Type: dataType, + }) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": roomID, + "data_type": dataType, + }).Tracef("Producing to topic '%s'", p.clientDataTopic) + + _, err = p.producer.PublishMsg(m) + return err +} + +// GetAndSendNotificationData reads the database and sends data about unread +// notifications to the Sync API server. +func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error { + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + + ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID) + if err != nil { + return err + } + + return p.sendNotificationData(userID, &eventutil.NotificationData{ + RoomID: roomID, + UnreadHighlightCount: int(nhighlight), + UnreadNotificationCount: int(ntotal), + }) +} + +// sendNotificationData sends data about unread notifications to the Sync API server. +func (p *SyncAPI) sendNotificationData(userID string, data *eventutil.NotificationData) error { + m := &nats.Msg{ + Subject: p.notificationDataTopic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, userID) + + var err error + m.Data, err = json.Marshal(data) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "user_id": userID, + "room_id": data.RoomID, + }).Tracef("Producing to topic '%s'", p.clientDataTopic) + + _, err = p.producer.PublishMsg(m) + return err +} diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index a131dac4..6d22fea9 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) type Database interface { @@ -89,6 +90,18 @@ type Database interface { // GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) + + InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + + UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error + GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string) error + RemovePushers(ctx context.Context, appid, pushkey string) error } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go new file mode 100644 index 00000000..7bcc0f9c --- /dev/null +++ b/userapi/storage/postgres/notifications_table.go @@ -0,0 +1,219 @@ +// Copyright 2021 Dan Peleg <dan@globekeeper.com> +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id BIGSERIAL PRIMARY KEY, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + }.Prepare(db) +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go new file mode 100644 index 00000000..670dc916 --- /dev/null +++ b/userapi/storage/postgres/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg <dan@globekeeper.com> +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id BIGSERIAL PRIMARY KEY, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index ac5c59b8..c74a999f 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) } + pusherTable, err := NewPostgresPusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewPostgresNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -95,6 +103,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewDummyWriter(), diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 5f1f9500..a58974b4 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -29,6 +29,7 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -47,6 +48,8 @@ type Database struct { KeyBackupVersions tables.KeyBackupVersionTable Devices tables.DevicesTable LoginTokens tables.LoginTokenTable + Notifications tables.NotificationTable + Pushers tables.PusherTable LoginTokenLifetime time.Duration ServerName gomatrixserverlib.ServerName BcryptCost int @@ -160,15 +163,12 @@ func (d *Database) createAccount( if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { return nil, err } - if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + prbs, err := json.Marshal(pushRuleSets) + if err != nil { + return nil, err + } + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { return nil, err } return account, nil @@ -670,3 +670,94 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { return d.LoginTokens.SelectLoginToken(ctx, token) } + +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + }) +} + +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) + return err + }) + return +} + +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) + return err + }) + return +} + +func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter) +} + +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { + return d.Notifications.SelectCount(ctx, nil, localpart, filter) +} + +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { + return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) +} + +func (d *Database) UpsertPusher( + ctx context.Context, p api.Pusher, localpart string, +) error { + data, err := json.Marshal(p.Data) + if err != nil { + return err + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Pushers.InsertPusher( + ctx, txn, + p.SessionID, + p.PushKey, + p.PushKeyTS, + p.Kind, + p.AppID, + p.AppDisplayName, + p.DeviceDisplayName, + p.ProfileTag, + p.Language, + string(data), + localpart) + }) +} + +// GetPushers returns the pushers matching the given localpart. +func (d *Database) GetPushers( + ctx context.Context, localpart string, +) ([]api.Pusher, error) { + return d.Pushers.SelectPushers(ctx, nil, localpart) +} + +// RemovePusher deletes one pusher +// Invoked when `append` is true and `kind` is null in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePusher( + ctx context.Context, appid, pushkey, localpart string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) + if err == sql.ErrNoRows { + return nil + } + return err + }) +} + +// RemovePushers deletes all pushers that match given App Id and Push Key pair. +// Invoked when `append` parameter is false in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePushers( + ctx context.Context, appid, pushkey string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Pushers.DeletePushers(ctx, txn, appid, pushkey) + }) +} diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go new file mode 100644 index 00000000..fcfb1aad --- /dev/null +++ b/userapi/storage/sqlite3/notifications_table.go @@ -0,0 +1,219 @@ +// Copyright 2021 Dan Peleg <dan@globekeeper.com> +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + }.Prepare(db) +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go new file mode 100644 index 00000000..e718792e --- /dev/null +++ b/userapi/storage/sqlite3/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg <dan@globekeeper.com> +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 98c24497..b5bb96c4 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) } + pusherTable, err := NewSQLitePusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewSQLiteNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewExclusiveWriter(), diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 12939ced..815e5119 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) type AccountDataTable interface { @@ -93,3 +94,42 @@ type ThreePIDTable interface { InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } + +type PusherTable interface { + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error + DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error +} + +type NotificationTable interface { + Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) +} + +type NotificationFilter uint32 + +const ( + // HighlightNotifications returns notifications that had a + // "highlight" tweak assigned to them from evaluating push rules. + HighlightNotifications NotificationFilter = 1 << iota + + // NonHighlightNotifications returns notifications that don't + // match HighlightNotifications. + NonHighlightNotifications + + // NoNotifications is a filter to exclude all types of + // notifications. It's useful as a zero value, but isn't likely to + // be used in a call to Notifications.Select*. + NoNotifications NotificationFilter = 0 + + // AllNotifications is a filter to include all types of + // notifications in Notifications.Select*. Note that PostgreSQL + // balks if this doesn't fit in INTEGER, even though we use + // uint32. + AllNotifications NotificationFilter = (1 << 31) - 1 +) diff --git a/userapi/userapi.go b/userapi/userapi.go index 4a5793ab..2382e951 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -18,11 +18,17 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/pushgateway" keyapi "github.com/matrix-org/dendrite/keyserver/api" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/consumers" "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" + "github.com/matrix-org/dendrite/userapi/producers" "github.com/matrix-org/dendrite/userapi/storage" "github.com/sirupsen/logrus" ) @@ -36,26 +42,49 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( - accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI, + appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, + rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client, ) api.UserInternalAPI { db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } - return newInternalAPI(db, cfg, appServices, keyAPI) -} + js := jetstream.Prepare(&cfg.Matrix.JetStream) -func newInternalAPI( - db storage.Database, - cfg *config.UserAPI, - appServices []config.ApplicationService, - keyAPI keyapi.KeyInternalAPI, -) api.UserInternalAPI { - return &internal.UserInternalAPI{ - DB: db, - ServerName: cfg.Matrix.ServerName, - AppServices: appServices, - KeyAPI: keyAPI, + syncProducer := producers.NewSyncAPI( + db, js, + // TODO: user API should handle syncs for account data. Right now, + // it's handled by clientapi, and hence uses its topic. When user + // API handles it for all account data, we can remove it from + // here. + cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), + cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData), + ) + + userAPI := &internal.UserInternalAPI{ + DB: db, + SyncProducer: syncProducer, + ServerName: cfg.Matrix.ServerName, + AppServices: appServices, + KeyAPI: keyAPI, + DisableTLSValidation: cfg.PushGatewayDisableTLSValidation, + } + + readConsumer := consumers.NewOutputReadUpdateConsumer( + base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer, + ) + if err := readConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API read update consumer") + } + + eventConsumer := consumers.NewOutputStreamEventConsumer( + base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer, + ) + if err := eventConsumer.Start(); err != nil { + logrus.WithError(err).Panic("failed to start user API streamed event consumer") } + + return userAPI } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 4214c07f..25319c4b 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -30,6 +30,7 @@ import ( "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -62,7 +63,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s }, } - return newInternalAPI(accountDB, cfg, nil, nil), accountDB + return &internal.UserInternalAPI{ + DB: accountDB, + ServerName: cfg.Matrix.ServerName, + }, accountDB } func TestQueryProfile(t *testing.T) { diff --git a/userapi/util/devices.go b/userapi/util/devices.go new file mode 100644 index 00000000..cbf3bd28 --- /dev/null +++ b/userapi/util/devices.go @@ -0,0 +1,100 @@ +package util + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + log "github.com/sirupsen/logrus" +) + +type PusherDevice struct { + Device pushgateway.Device + Pusher *api.Pusher + URL string + Format string +} + +// GetPushDevices pushes to the configured devices of a local user. +func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { + pushers, err := db.GetPushers(ctx, localpart) + if err != nil { + return nil, err + } + + devices := make([]*PusherDevice, 0, len(pushers)) + for _, pusher := range pushers { + var url, format string + data := pusher.Data + switch pusher.Kind { + case api.EmailKind: + url = "mailto:" + + case api.HTTPKind: + // TODO: The spec says only event_id_only is supported, + // but Sytests assume "" means "full notification". + fmtIface := pusher.Data["format"] + var ok bool + format, ok = fmtIface.(string) + if ok && format != "event_id_only" { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + }).Errorf("Only data.format event_id_only or empty is supported") + continue + } + + urlIface := pusher.Data["url"] + url, ok = urlIface.(string) + if !ok { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + }).Errorf("No data.url configured for HTTP Pusher") + continue + } + data = mapWithout(data, "url") + + default: + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id": pusher.AppID, + "kind": pusher.Kind, + }).Errorf("Unhandled pusher kind") + continue + } + + devices = append(devices, &PusherDevice{ + Device: pushgateway.Device{ + AppID: pusher.AppID, + Data: data, + PushKey: pusher.PushKey, + PushKeyTS: pusher.PushKeyTS, + Tweaks: tweaks, + }, + Pusher: &pusher, + URL: url, + Format: format, + }) + } + + return devices, nil +} + +// mapWithout returns a shallow copy of the map, without the given +// key. Returns nil if the resulting map is empty. +func mapWithout(m map[string]interface{}, key string) map[string]interface{} { + ret := make(map[string]interface{}, len(m)) + for k, v := range m { + // The specification says we do not send "url". + if k == key { + continue + } + ret[k] = v + } + if len(ret) == 0 { + return nil + } + return ret +} diff --git a/userapi/util/notify.go b/userapi/util/notify.go new file mode 100644 index 00000000..ff206bd3 --- /dev/null +++ b/userapi/util/notify.go @@ -0,0 +1,76 @@ +package util + +import ( + "context" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/dendrite/userapi/storage/tables" + log "github.com/sirupsen/logrus" +) + +// NotifyUserCountsAsync sends notifications to a local user's +// notification destinations. Database lookups run synchronously, but +// a single goroutine is started when talking to the Push +// gateways. There is no way to know when the background goroutine has +// finished. +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error { + pusherDevices, err := GetPushDevices(ctx, localpart, nil, db) + if err != nil { + return err + } + + if len(pusherDevices) == 0 { + return nil + } + + userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications) + if err != nil { + return err + } + + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": pusherDevices[0].Device.AppID, + "pushkey": pusherDevices[0].Device.PushKey, + }).Tracef("Notifying HTTP push gateway about notification counts") + + // TODO: think about bounding this to one per user, and what + // ordering guarantees we must provide. + go func() { + // This background processing cannot be tied to a request. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // TODO: we could batch all devices with the same URL, but + // Sytest requires consumers/roomserver.go to do it + // one-by-one, so we do the same here. + for _, pusherDevice := range pusherDevices { + // TODO: support "email". + if !strings.HasPrefix(pusherDevice.URL, "http") { + continue + } + + req := pushgateway.NotifyRequest{ + Notification: pushgateway.Notification{ + Counts: &pushgateway.Counts{ + Unread: int(userNumUnreadNotifs), + }, + Devices: []*pushgateway.Device{&pusherDevice.Device}, + }, + } + if err := pgClient.Notify(ctx, pusherDevice.URL, &req, &pushgateway.NotifyResponse{}); err != nil { + log.WithFields(log.Fields{ + "localpart": localpart, + "app_id0": pusherDevice.Device.AppID, + "pushkey": pusherDevice.Device.PushKey, + }).WithError(err).Error("HTTP push gateway request failed") + return + } + } + }() + + return nil +} |