aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
Diffstat (limited to 'userapi')
-rw-r--r--userapi/api/api.go83
-rw-r--r--userapi/api/api_trace.go30
-rw-r--r--userapi/consumers/syncapi_readupdate.go136
-rw-r--r--userapi/consumers/syncapi_streamevent.go588
-rw-r--r--userapi/internal/api.go171
-rw-r--r--userapi/inthttp/client.go61
-rw-r--r--userapi/inthttp/server.go82
-rw-r--r--userapi/producers/syncapi.go104
-rw-r--r--userapi/storage/interface.go13
-rw-r--r--userapi/storage/postgres/notifications_table.go219
-rw-r--r--userapi/storage/postgres/pusher_table.go157
-rw-r--r--userapi/storage/postgres/storage.go10
-rw-r--r--userapi/storage/shared/storage.go109
-rw-r--r--userapi/storage/sqlite3/notifications_table.go219
-rw-r--r--userapi/storage/sqlite3/pusher_table.go157
-rw-r--r--userapi/storage/sqlite3/storage.go10
-rw-r--r--userapi/storage/tables/interface.go40
-rw-r--r--userapi/userapi.go57
-rw-r--r--userapi/userapi_test.go6
-rw-r--r--userapi/util/devices.go100
-rw-r--r--userapi/util/notify.go76
21 files changed, 2402 insertions, 26 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 2be662e5..e9cdbe01 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/internal/pushrules"
)
// UserInternalAPI is the internal API for information about users and devices.
@@ -28,6 +29,7 @@ type UserInternalAPI interface {
LoginTokenInternalAPI
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
+
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
@@ -37,6 +39,10 @@ type UserInternalAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
+ PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
+ PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
+ PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error
+
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
@@ -45,6 +51,9 @@ type UserInternalAPI interface {
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
+ QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
+ QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error
+ QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
}
type PerformKeyBackupRequest struct {
@@ -424,3 +433,77 @@ const (
// AccountTypeAppService indicates this is an appservice account
AccountTypeAppService AccountType = 4
)
+
+type QueryPushersRequest struct {
+ Localpart string
+}
+
+type QueryPushersResponse struct {
+ Pushers []Pusher `json:"pushers"`
+}
+
+type PerformPusherSetRequest struct {
+ Pusher // Anonymous field because that's how clientapi unmarshals it.
+ Localpart string
+ Append bool `json:"append"`
+}
+
+type PerformPusherDeletionRequest struct {
+ Localpart string
+ SessionID int64
+}
+
+// Pusher represents a push notification subscriber
+type Pusher struct {
+ SessionID int64 `json:"session_id,omitempty"`
+ PushKey string `json:"pushkey"`
+ PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
+ Kind PusherKind `json:"kind"`
+ AppID string `json:"app_id"`
+ AppDisplayName string `json:"app_display_name"`
+ DeviceDisplayName string `json:"device_display_name"`
+ ProfileTag string `json:"profile_tag"`
+ Language string `json:"lang"`
+ Data map[string]interface{} `json:"data"`
+}
+
+type PusherKind string
+
+const (
+ EmailKind PusherKind = "email"
+ HTTPKind PusherKind = "http"
+)
+
+type PerformPushRulesPutRequest struct {
+ UserID string `json:"user_id"`
+ RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
+}
+
+type QueryPushRulesRequest struct {
+ UserID string `json:"user_id"`
+}
+
+type QueryPushRulesResponse struct {
+ RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
+}
+
+type QueryNotificationsRequest struct {
+ Localpart string `json:"localpart"` // Required.
+ From string `json:"from,omitempty"`
+ Limit int `json:"limit,omitempty"`
+ Only string `json:"only,omitempty"`
+}
+
+type QueryNotificationsResponse struct {
+ NextToken string `json:"next_token"`
+ Notifications []*Notification `json:"notifications"` // Required.
+}
+
+type Notification struct {
+ Actions []*pushrules.Action `json:"actions"` // Required.
+ Event gomatrixserverlib.ClientEvent `json:"event"` // Required.
+ ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional.
+ Read bool `json:"read"` // Required.
+ RoomID string `json:"room_id"` // Required.
+ TS gomatrixserverlib.Timestamp `json:"ts"` // Required.
+}
diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go
index aa069f40..9334f445 100644
--- a/userapi/api/api_trace.go
+++ b/userapi/api/api_trace.go
@@ -79,6 +79,21 @@ func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *Perfor
util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res))
return err
}
+func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error {
+ err := t.Impl.PerformPusherSet(ctx, req, res)
+ util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res))
+ return err
+}
+func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error {
+ err := t.Impl.PerformPusherDeletion(ctx, req, res)
+ util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res))
+ return err
+}
+func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error {
+ err := t.Impl.PerformPushRulesPut(ctx, req, res)
+ util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res))
+ return err
+}
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) {
t.Impl.QueryKeyBackup(ctx, req, res)
util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res))
@@ -118,6 +133,21 @@ func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryO
util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res))
return err
}
+func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error {
+ err := t.Impl.QueryPushers(ctx, req, res)
+ util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res))
+ return err
+}
+func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error {
+ err := t.Impl.QueryPushRules(ctx, req, res)
+ util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res))
+ return err
+}
+func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error {
+ err := t.Impl.QueryNotifications(ctx, req, res)
+ util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res))
+ return err
+}
func js(thing interface{}) string {
b, err := json.Marshal(thing)
diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go
new file mode 100644
index 00000000..2e58020b
--- /dev/null
+++ b/userapi/consumers/syncapi_readupdate.go
@@ -0,0 +1,136 @@
+package consumers
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal/pushgateway"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ uapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/producers"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/dendrite/userapi/util"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ log "github.com/sirupsen/logrus"
+)
+
+type OutputReadUpdateConsumer struct {
+ ctx context.Context
+ cfg *config.UserAPI
+ jetstream nats.JetStreamContext
+ durable string
+ db storage.Database
+ pgClient pushgateway.Client
+ ServerName gomatrixserverlib.ServerName
+ topic string
+ userAPI uapi.UserInternalAPI
+ syncProducer *producers.SyncAPI
+}
+
+func NewOutputReadUpdateConsumer(
+ process *process.ProcessContext,
+ cfg *config.UserAPI,
+ js nats.JetStreamContext,
+ store storage.Database,
+ pgClient pushgateway.Client,
+ userAPI uapi.UserInternalAPI,
+ syncProducer *producers.SyncAPI,
+) *OutputReadUpdateConsumer {
+ return &OutputReadUpdateConsumer{
+ ctx: process.Context(),
+ cfg: cfg,
+ jetstream: js,
+ db: store,
+ ServerName: cfg.Matrix.ServerName,
+ durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"),
+ topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate),
+ pgClient: pgClient,
+ userAPI: userAPI,
+ syncProducer: syncProducer,
+ }
+}
+
+func (s *OutputReadUpdateConsumer) Start() error {
+ if err := jetstream.JetStreamConsumer(
+ s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
+ nats.DeliverAll(), nats.ManualAck(),
+ ); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
+ var read types.ReadUpdate
+ if err := json.Unmarshal(msg.Data, &read); err != nil {
+ log.WithError(err).Error("userapi clientapi consumer: message parse failure")
+ return true
+ }
+ if read.FullyRead == 0 && read.Read == 0 {
+ return true
+ }
+
+ userID := string(msg.Header.Get(jetstream.UserID))
+ roomID := string(msg.Header.Get(jetstream.RoomID))
+
+ localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
+ return true
+ }
+ if domain != s.ServerName {
+ log.Error("userapi clientapi consumer: not a local user")
+ return true
+ }
+
+ log := log.WithFields(log.Fields{
+ "room_id": roomID,
+ "user_id": userID,
+ })
+ log.Tracef("Received read update from sync API: %#v", read)
+
+ if read.Read > 0 {
+ updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true)
+ if err != nil {
+ log.WithError(err).Error("userapi EDU consumer")
+ return false
+ }
+
+ if updated {
+ if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
+ log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
+ return false
+ }
+ if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
+ log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
+ return false
+ }
+ }
+ }
+
+ if read.FullyRead > 0 {
+ deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead))
+ if err != nil {
+ log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
+ return false
+ }
+
+ if deleted {
+ if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
+ log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed")
+ return false
+ }
+
+ if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil {
+ log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
+ return false
+ }
+ }
+ }
+
+ return true
+}
diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go
new file mode 100644
index 00000000..11081327
--- /dev/null
+++ b/userapi/consumers/syncapi_streamevent.go
@@ -0,0 +1,588 @@
+package consumers
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/internal/pushgateway"
+ "github.com/matrix-org/dendrite/internal/pushrules"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/producers"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/util"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ log "github.com/sirupsen/logrus"
+)
+
+type OutputStreamEventConsumer struct {
+ ctx context.Context
+ cfg *config.UserAPI
+ userAPI api.UserInternalAPI
+ rsAPI rsapi.RoomserverInternalAPI
+ jetstream nats.JetStreamContext
+ durable string
+ db storage.Database
+ topic string
+ pgClient pushgateway.Client
+ syncProducer *producers.SyncAPI
+}
+
+func NewOutputStreamEventConsumer(
+ process *process.ProcessContext,
+ cfg *config.UserAPI,
+ js nats.JetStreamContext,
+ store storage.Database,
+ pgClient pushgateway.Client,
+ userAPI api.UserInternalAPI,
+ rsAPI rsapi.RoomserverInternalAPI,
+ syncProducer *producers.SyncAPI,
+) *OutputStreamEventConsumer {
+ return &OutputStreamEventConsumer{
+ ctx: process.Context(),
+ cfg: cfg,
+ jetstream: js,
+ db: store,
+ durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"),
+ topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent),
+ pgClient: pgClient,
+ userAPI: userAPI,
+ rsAPI: rsAPI,
+ syncProducer: syncProducer,
+ }
+}
+
+func (s *OutputStreamEventConsumer) Start() error {
+ if err := jetstream.JetStreamConsumer(
+ s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
+ nats.DeliverAll(), nats.ManualAck(),
+ ); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
+ var output types.StreamedEvent
+ output.Event = &gomatrixserverlib.HeaderedEvent{}
+ if err := json.Unmarshal(msg.Data, &output); err != nil {
+ log.WithError(err).Errorf("userapi consumer: message parse failure")
+ return true
+ }
+ if output.Event.Event == nil {
+ log.Errorf("userapi consumer: expected event")
+ return true
+ }
+
+ log.WithFields(log.Fields{
+ "event_id": output.Event.EventID(),
+ "event_type": output.Event.Type(),
+ "stream_pos": output.StreamPosition,
+ }).Tracef("Received message from sync API: %#v", output)
+
+ if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil {
+ log.WithFields(log.Fields{
+ "event_id": output.Event.EventID(),
+ }).WithError(err).Errorf("userapi consumer: process room event failure")
+ }
+
+ return true
+}
+
+func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error {
+ members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
+ if err != nil {
+ return fmt.Errorf("s.localRoomMembers: %w", err)
+ }
+
+ if event.Type() == gomatrixserverlib.MRoomMember {
+ cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll)
+ var member *localMembership
+ member, err = newLocalMembership(&cevent)
+ if err != nil {
+ return fmt.Errorf("newLocalMembership: %w", err)
+ }
+ if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName {
+ // localRoomMembers only adds joined members. An invite
+ // should also be pushed to the target user.
+ members = append(members, member)
+ }
+ }
+
+ // TODO: run in parallel with localRoomMembers.
+ roomName, err := s.roomName(ctx, event)
+ if err != nil {
+ return fmt.Errorf("s.roomName: %w", err)
+ }
+
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "room_id": event.RoomID(),
+ "num_members": len(members),
+ "room_size": roomSize,
+ }).Tracef("Notifying members")
+
+ // Notification.UserIsTarget is a per-member field, so we
+ // cannot group all users in a single request.
+ //
+ // TODO: does it have to be set? It's not required, and
+ // removing it means we can send all notifications to
+ // e.g. Element's Push gateway in one go.
+ for _, mem := range members {
+ if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil {
+ log.WithFields(log.Fields{
+ "localpart": mem.Localpart,
+ }).WithError(err).Debugf("Unable to push to local user")
+ continue
+ }
+ }
+
+ return nil
+}
+
+type localMembership struct {
+ gomatrixserverlib.MemberContent
+ UserID string
+ Localpart string
+ Domain gomatrixserverlib.ServerName
+}
+
+func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) {
+ if event.StateKey == nil {
+ return nil, fmt.Errorf("missing state_key")
+ }
+
+ var member localMembership
+ if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil {
+ return nil, err
+ }
+
+ localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey)
+ if err != nil {
+ return nil, err
+ }
+
+ member.UserID = *event.StateKey
+ member.Localpart = localpart
+ member.Domain = domain
+ return &member, nil
+}
+
+// localRoomMembers fetches the current local members of a room, and
+// the total number of members.
+func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
+ req := &rsapi.QueryMembershipsForRoomRequest{
+ RoomID: roomID,
+ JoinedOnly: true,
+ }
+ var res rsapi.QueryMembershipsForRoomResponse
+
+ // XXX: This could potentially race if the state for the event is not known yet
+ // e.g. the event came over federation but we do not have the full state persisted.
+ if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil {
+ return nil, 0, err
+ }
+
+ var members []*localMembership
+ var ntotal int
+ for _, event := range res.JoinEvents {
+ member, err := newLocalMembership(&event)
+ if err != nil {
+ log.WithError(err).Errorf("Parsing MemberContent")
+ continue
+ }
+ if member.Membership != gomatrixserverlib.Join {
+ continue
+ }
+ if member.Domain != s.cfg.Matrix.ServerName {
+ continue
+ }
+
+ ntotal++
+ members = append(members, member)
+ }
+
+ return members, ntotal, nil
+}
+
+// roomName returns the name in the event (if type==m.room.name), or
+// looks it up in roomserver. If there is no name,
+// m.room.canonical_alias is consulted. Returns an empty string if the
+// room has no name.
+func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
+ if event.Type() == gomatrixserverlib.MRoomName {
+ name, err := unmarshalRoomName(event)
+ if err != nil {
+ return "", err
+ }
+
+ if name != "" {
+ return name, nil
+ }
+ }
+
+ req := &rsapi.QueryCurrentStateRequest{
+ RoomID: event.RoomID(),
+ StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple},
+ }
+ var res rsapi.QueryCurrentStateResponse
+
+ if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil {
+ return "", nil
+ }
+
+ if eventS := res.StateEvents[roomNameTuple]; eventS != nil {
+ return unmarshalRoomName(eventS)
+ }
+
+ if event.Type() == gomatrixserverlib.MRoomCanonicalAlias {
+ alias, err := unmarshalCanonicalAlias(event)
+ if err != nil {
+ return "", err
+ }
+
+ if alias != "" {
+ return alias, nil
+ }
+ }
+
+ if event = res.StateEvents[canonicalAliasTuple]; event != nil {
+ return unmarshalCanonicalAlias(event)
+ }
+
+ return "", nil
+}
+
+var (
+ canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias}
+ roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName}
+)
+
+func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) {
+ var nc eventutil.NameContent
+ if err := json.Unmarshal(event.Content(), &nc); err != nil {
+ return "", fmt.Errorf("unmarshaling NameContent: %w", err)
+ }
+
+ return nc.Name, nil
+}
+
+func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) {
+ var cac eventutil.CanonicalAliasContent
+ if err := json.Unmarshal(event.Content(), &cac); err != nil {
+ return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err)
+ }
+
+ return cac.Alias, nil
+}
+
+// notifyLocal finds the right push actions for a local user, given an event.
+func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error {
+ actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
+ if err != nil {
+ return err
+ }
+ a, tweaks, err := pushrules.ActionsToTweaks(actions)
+ if err != nil {
+ return err
+ }
+ // TODO: support coalescing.
+ if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "room_id": event.RoomID(),
+ "localpart": mem.Localpart,
+ }).Tracef("Push rule evaluation rejected the event")
+ return nil
+ }
+
+ devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
+ if err != nil {
+ return err
+ }
+
+ n := &api.Notification{
+ Actions: actions,
+ // UNSPEC: the spec doesn't say this is a ClientEvent, but the
+ // fields seem to match. room_id should be missing, which
+ // matches the behaviour of FormatSync.
+ Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync),
+ // TODO: this is per-device, but it's not part of the primary
+ // key. So inserting one notification per profile tag doesn't
+ // make sense. What is this supposed to be? Sytests require it
+ // to "work", but they only use a single device.
+ ProfileTag: profileTag,
+ RoomID: event.RoomID(),
+ TS: gomatrixserverlib.AsTimestamp(time.Now()),
+ }
+ if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil {
+ return err
+ }
+
+ if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
+ return err
+ }
+
+ // We do this after InsertNotification. Thus, this should always return >=1.
+ userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "room_id": event.RoomID(),
+ "localpart": mem.Localpart,
+ "num_urls": len(devicesByURLAndFormat),
+ "num_unread": userNumUnreadNotifs,
+ }).Tracef("Notifying single member")
+
+ // Push gateways are out of our control, and we cannot risk
+ // looking up the server on a misbehaving push gateway. Each user
+ // receives a goroutine now that all internal API calls have been
+ // made.
+ //
+ // TODO: think about bounding this to one per user, and what
+ // ordering guarantees we must provide.
+ go func() {
+ // This background processing cannot be tied to a request.
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ var rejected []*pushgateway.Device
+ for url, fmts := range devicesByURLAndFormat {
+ for format, devices := range fmts {
+ // TODO: support "email".
+ if !strings.HasPrefix(url, "http") {
+ continue
+ }
+
+ // UNSPEC: the specification suggests there can be
+ // more than one device per request. There is at least
+ // one Sytest that expects one HTTP request per
+ // device, rather than per URL. For now, we must
+ // notify each one separately.
+ for _, dev := range devices {
+ rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs))
+ if err != nil {
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "localpart": mem.Localpart,
+ }).WithError(err).Errorf("Unable to notify HTTP pusher")
+ continue
+ }
+ rejected = append(rejected, rej...)
+ }
+ }
+ }
+
+ if len(rejected) > 0 {
+ s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
+ }
+ }()
+
+ return nil
+}
+
+// evaluatePushRules fetches and evaluates the push rules of a local
+// user. Returns actions (including dont_notify).
+func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
+ if event.Sender() == mem.UserID {
+ // SPEC: Homeservers MUST NOT notify the Push Gateway for
+ // events that the user has sent themselves.
+ return nil, nil
+ }
+
+ var res api.QueryPushRulesResponse
+ if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil {
+ return nil, err
+ }
+
+ ec := &ruleSetEvalContext{
+ ctx: ctx,
+ rsAPI: s.rsAPI,
+ mem: mem,
+ roomID: event.RoomID(),
+ roomSize: roomSize,
+ }
+ eval := pushrules.NewRuleSetEvaluator(ec, &res.RuleSets.Global)
+ rule, err := eval.MatchEvent(event.Event)
+ if err != nil {
+ return nil, err
+ }
+ if rule == nil {
+ // SPEC: If no rules match an event, the homeserver MUST NOT
+ // notify the Push Gateway for that event.
+ return nil, err
+ }
+
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "room_id": event.RoomID(),
+ "localpart": mem.Localpart,
+ "rule_id": rule.RuleID,
+ }).Tracef("Matched a push rule")
+
+ return rule.Actions, nil
+}
+
+type ruleSetEvalContext struct {
+ ctx context.Context
+ rsAPI rsapi.RoomserverInternalAPI
+ mem *localMembership
+ roomID string
+ roomSize int
+}
+
+func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName }
+
+func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
+
+func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) {
+ req := &rsapi.QueryLatestEventsAndStateRequest{
+ RoomID: rse.roomID,
+ StateToFetch: []gomatrixserverlib.StateKeyTuple{
+ {EventType: gomatrixserverlib.MRoomPowerLevels},
+ },
+ }
+ var res rsapi.QueryLatestEventsAndStateResponse
+ if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil {
+ return false, err
+ }
+ for _, ev := range res.StateEvents {
+ if ev.Type() != gomatrixserverlib.MRoomPowerLevels {
+ continue
+ }
+
+ plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event)
+ if err != nil {
+ return false, err
+ }
+ return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil
+ }
+ return true, nil
+}
+
+// localPushDevices pushes to the configured devices of a local
+// user. The map keys are [url][format].
+func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
+ pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
+ if err != nil {
+ return nil, "", err
+ }
+
+ var profileTag string
+ devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices))
+ for _, pusherDevice := range pusherDevices {
+ if profileTag == "" {
+ profileTag = pusherDevice.Pusher.ProfileTag
+ }
+
+ url := pusherDevice.URL
+ if devicesByURL[url] == nil {
+ devicesByURL[url] = make(map[string][]*pushgateway.Device, 2)
+ }
+ devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device)
+ }
+
+ return devicesByURL, profileTag, nil
+}
+
+// notifyHTTP performs a notificatation to a Push Gateway.
+func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
+ logger := log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "url": url,
+ "localpart": localpart,
+ "num_devices": len(devices),
+ })
+
+ var req pushgateway.NotifyRequest
+ switch format {
+ case "event_id_only":
+ req = pushgateway.NotifyRequest{
+ Notification: pushgateway.Notification{
+ Counts: &pushgateway.Counts{},
+ Devices: devices,
+ EventID: event.EventID(),
+ RoomID: event.RoomID(),
+ },
+ }
+
+ default:
+ req = pushgateway.NotifyRequest{
+ Notification: pushgateway.Notification{
+ Content: event.Content(),
+ Counts: &pushgateway.Counts{
+ Unread: userNumUnreadNotifs,
+ },
+ Devices: devices,
+ EventID: event.EventID(),
+ ID: event.EventID(),
+ RoomID: event.RoomID(),
+ RoomName: roomName,
+ Sender: event.Sender(),
+ Type: event.Type(),
+ },
+ }
+ if mem, err := event.Membership(); err == nil {
+ req.Notification.Membership = mem
+ }
+ if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) {
+ req.Notification.UserIsTarget = true
+ }
+ }
+
+ logger.Debugf("Notifying push gateway %s", url)
+ var res pushgateway.NotifyResponse
+ if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil {
+ logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
+ return nil, err
+ }
+ logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result")
+
+ if len(res.Rejected) == 0 {
+ return nil, nil
+ }
+
+ devMap := make(map[string]*pushgateway.Device, len(devices))
+ for _, d := range devices {
+ devMap[d.PushKey] = d
+ }
+ rejected := make([]*pushgateway.Device, 0, len(res.Rejected))
+ for _, pushKey := range res.Rejected {
+ d := devMap[pushKey]
+ if d != nil {
+ rejected = append(rejected, d)
+ }
+ }
+
+ return rejected, nil
+}
+
+// deleteRejectedPushers deletes the pushers associated with the given devices.
+func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id0": devices[0].AppID,
+ "num_devices": len(devices),
+ }).Warnf("Deleting pushers rejected by the HTTP push gateway")
+
+ for _, d := range devices {
+ if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ }).WithError(err).Errorf("Unable to delete rejected pusher")
+ }
+ }
+}
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index f54cc613..7a42fc60 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -20,6 +20,8 @@ import (
"encoding/json"
"errors"
"fmt"
+ "strconv"
+ "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -27,16 +29,22 @@ import (
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil"
+ "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
type UserInternalAPI struct {
- DB storage.Database
- ServerName gomatrixserverlib.ServerName
+ DB storage.Database
+ SyncProducer *producers.SyncAPI
+
+ DisableTLSValidation bool
+ ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
KeyAPI keyapi.KeyInternalAPI
@@ -595,3 +603,162 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
}
res.Keys = result
}
+
+func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
+ if req.Limit == 0 || req.Limit > 1000 {
+ req.Limit = 1000
+ }
+
+ var fromID int64
+ var err error
+ if req.From != "" {
+ fromID, err = strconv.ParseInt(req.From, 10, 64)
+ if err != nil {
+ return fmt.Errorf("QueryNotifications: parsing 'from': %w", err)
+ }
+ }
+ var filter tables.NotificationFilter = tables.AllNotifications
+ if req.Only == "highlight" {
+ filter = tables.HighlightNotifications
+ }
+ notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
+ if err != nil {
+ return err
+ }
+ if notifs == nil {
+ // This ensures empty is JSON-encoded as [] instead of null.
+ notifs = []*api.Notification{}
+ }
+ res.Notifications = notifs
+ if lastID >= 0 {
+ res.NextToken = strconv.FormatInt(lastID+1, 10)
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error {
+ util.GetLogger(ctx).WithFields(logrus.Fields{
+ "localpart": req.Localpart,
+ "pushkey": req.Pusher.PushKey,
+ "display_name": req.Pusher.AppDisplayName,
+ }).Info("PerformPusherCreation")
+ if !req.Append {
+ err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey)
+ if err != nil {
+ return err
+ }
+ }
+ if req.Pusher.Kind == "" {
+ return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
+ }
+ if req.Pusher.PushKeyTS == 0 {
+ req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
+ }
+ return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
+}
+
+func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
+ pushers, err := a.DB.GetPushers(ctx, req.Localpart)
+ if err != nil {
+ return err
+ }
+ for i := range pushers {
+ logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
+ if pushers[i].SessionID != req.SessionID {
+ err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
+ var err error
+ res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
+ return err
+}
+
+func (a *UserInternalAPI) PerformPushRulesPut(
+ ctx context.Context,
+ req *api.PerformPushRulesPutRequest,
+ _ *struct{},
+) error {
+ bs, err := json.Marshal(&req.RuleSets)
+ if err != nil {
+ return err
+ }
+ userReq := api.InputAccountDataRequest{
+ UserID: req.UserID,
+ DataType: pushRulesAccountDataType,
+ AccountData: json.RawMessage(bs),
+ }
+ var userRes api.InputAccountDataResponse // empty
+ if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil {
+ return err
+ }
+
+ if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil {
+ util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed")
+ }
+
+ return nil
+}
+
+func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
+ userReq := api.QueryAccountDataRequest{
+ UserID: req.UserID,
+ DataType: pushRulesAccountDataType,
+ }
+ var userRes api.QueryAccountDataResponse
+ if err := a.QueryAccountData(ctx, &userReq, &userRes); err != nil {
+ return err
+ }
+ bs, ok := userRes.GlobalAccountData[pushRulesAccountDataType]
+ if ok {
+ // Legacy Dendrite users will have completely empty push rules, so we should
+ // detect that situation and set some defaults.
+ var rules struct {
+ G struct {
+ Content []json.RawMessage `json:"content"`
+ Override []json.RawMessage `json:"override"`
+ Room []json.RawMessage `json:"room"`
+ Sender []json.RawMessage `json:"sender"`
+ Underride []json.RawMessage `json:"underride"`
+ } `json:"global"`
+ }
+ if err := json.Unmarshal([]byte(bs), &rules); err == nil {
+ count := len(rules.G.Content) + len(rules.G.Override) +
+ len(rules.G.Room) + len(rules.G.Sender) + len(rules.G.Underride)
+ ok = count > 0
+ }
+ }
+ if !ok {
+ // If we didn't find any default push rules then we should just generate some
+ // fresh ones.
+ localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
+ if err != nil {
+ return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
+ }
+ pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, a.ServerName)
+ prbs, err := json.Marshal(pushRuleSets)
+ if err != nil {
+ return fmt.Errorf("failed to marshal default push rules: %w", err)
+ }
+ if err := a.DB.SaveAccountData(ctx, localpart, "", pushRulesAccountDataType, json.RawMessage(prbs)); err != nil {
+ return fmt.Errorf("failed to save default push rules: %w", err)
+ }
+ res.RuleSets = pushRuleSets
+ return nil
+ }
+ var data pushrules.AccountRuleSets
+ if err := json.Unmarshal([]byte(bs), &data); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("json.Unmarshal of push rules failed")
+ return err
+ }
+ res.RuleSets = &data
+ return nil
+}
+
+const pushRulesAccountDataType = "m.push_rules"
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index 1599d463..8ec649ad 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -37,6 +37,9 @@ const (
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
PerformKeyBackupPath = "/userapi/performKeyBackup"
+ PerformPusherSetPath = "/pushserver/performPusherSet"
+ PerformPusherDeletionPath = "/pushserver/performPusherDeletion"
+ PerformPushRulesPutPath = "/pushserver/performPushRulesPut"
QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile"
@@ -46,6 +49,9 @@ const (
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
+ QueryPushersPath = "/pushserver/queryPushers"
+ QueryPushRulesPath = "/pushserver/queryPushRules"
+ QueryNotificationsPath = "/pushserver/queryNotifications"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@@ -249,3 +255,58 @@ func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.Query
res.Error = err.Error()
}
}
+
+func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications")
+ defer span.Finish()
+
+ return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res)
+}
+
+func (h *httpUserInternalAPI) PerformPusherSet(
+ ctx context.Context,
+ request *api.PerformPusherSetRequest,
+ response *struct{},
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet")
+ defer span.Finish()
+
+ apiURL := h.apiURL + PerformPusherSetPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion")
+ defer span.Finish()
+
+ apiURL := h.apiURL + PerformPusherDeletionPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
+
+func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers")
+ defer span.Finish()
+
+ apiURL := h.apiURL + QueryPushersPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
+
+func (h *httpUserInternalAPI) PerformPushRulesPut(
+ ctx context.Context,
+ request *api.PerformPushRulesPutRequest,
+ response *struct{},
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut")
+ defer span.Finish()
+
+ apiURL := h.apiURL + PerformPushRulesPutPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules")
+ defer span.Finish()
+
+ apiURL := h.apiURL + QueryPushRulesPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go
index d00ee042..526f9957 100644
--- a/userapi/inthttp/server.go
+++ b/userapi/inthttp/server.go
@@ -265,4 +265,86 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
+ internalAPIMux.Handle(QueryNotificationsPath,
+ httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse {
+ var request api.QueryNotificationsRequest
+ var response api.QueryNotificationsResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.QueryNotifications(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+
+ internalAPIMux.Handle(PerformPusherSetPath,
+ httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse {
+ request := api.PerformPusherSetRequest{}
+ response := struct{}{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ internalAPIMux.Handle(PerformPusherDeletionPath,
+ httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse {
+ request := api.PerformPusherDeletionRequest{}
+ response := struct{}{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+
+ internalAPIMux.Handle(QueryPushersPath,
+ httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse {
+ request := api.QueryPushersRequest{}
+ response := api.QueryPushersResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.QueryPushers(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+
+ internalAPIMux.Handle(PerformPushRulesPutPath,
+ httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse {
+ request := api.PerformPushRulesPutRequest{}
+ response := struct{}{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+
+ internalAPIMux.Handle(QueryPushRulesPath,
+ httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse {
+ request := api.QueryPushRulesRequest{}
+ response := api.QueryPushRulesResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.QueryPushRules(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
}
diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go
new file mode 100644
index 00000000..4a206f33
--- /dev/null
+++ b/userapi/producers/syncapi.go
@@ -0,0 +1,104 @@
+package producers
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ log "github.com/sirupsen/logrus"
+)
+
+type JetStreamPublisher interface {
+ PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error)
+}
+
+// SyncAPI produces messages for the Sync API server to consume.
+type SyncAPI struct {
+ db storage.Database
+ producer JetStreamPublisher
+ clientDataTopic string
+ notificationDataTopic string
+}
+
+func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
+ return &SyncAPI{
+ db: db,
+ producer: js,
+ clientDataTopic: clientDataTopic,
+ notificationDataTopic: notificationDataTopic,
+ }
+}
+
+// SendAccountData sends account data to the Sync API server.
+func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error {
+ m := &nats.Msg{
+ Subject: p.clientDataTopic,
+ Header: nats.Header{},
+ }
+ m.Header.Set(jetstream.UserID, userID)
+
+ var err error
+ m.Data, err = json.Marshal(eventutil.AccountData{
+ RoomID: roomID,
+ Type: dataType,
+ })
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "user_id": userID,
+ "room_id": roomID,
+ "data_type": dataType,
+ }).Tracef("Producing to topic '%s'", p.clientDataTopic)
+
+ _, err = p.producer.PublishMsg(m)
+ return err
+}
+
+// GetAndSendNotificationData reads the database and sends data about unread
+// notifications to the Sync API server.
+func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
+ localpart, _, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return err
+ }
+
+ ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
+ if err != nil {
+ return err
+ }
+
+ return p.sendNotificationData(userID, &eventutil.NotificationData{
+ RoomID: roomID,
+ UnreadHighlightCount: int(nhighlight),
+ UnreadNotificationCount: int(ntotal),
+ })
+}
+
+// sendNotificationData sends data about unread notifications to the Sync API server.
+func (p *SyncAPI) sendNotificationData(userID string, data *eventutil.NotificationData) error {
+ m := &nats.Msg{
+ Subject: p.notificationDataTopic,
+ Header: nats.Header{},
+ }
+ m.Header.Set(jetstream.UserID, userID)
+
+ var err error
+ m.Data, err = json.Marshal(data)
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "user_id": userID,
+ "room_id": data.RoomID,
+ }).Tracef("Producing to topic '%s'", p.clientDataTopic)
+
+ _, err = p.producer.PublishMsg(m)
+ return err
+}
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index a131dac4..6d22fea9 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
type Database interface {
@@ -89,6 +90,18 @@ type Database interface {
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
+
+ InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
+ DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
+ SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error)
+ GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
+ GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
+ GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
+
+ UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
+ GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
+ RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
+ RemovePushers(ctx context.Context, appid, pushkey string) error
}
// Err3PIDInUse is the error returned when trying to save an association involving
diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go
new file mode 100644
index 00000000..7bcc0f9c
--- /dev/null
+++ b/userapi/storage/postgres/notifications_table.go
@@ -0,0 +1,219 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+type notificationsStatements struct {
+ insertStmt *sql.Stmt
+ deleteUpToStmt *sql.Stmt
+ updateReadStmt *sql.Stmt
+ selectStmt *sql.Stmt
+ selectCountStmt *sql.Stmt
+ selectRoomCountsStmt *sql.Stmt
+}
+
+const notificationSchema = `
+CREATE TABLE IF NOT EXISTS userapi_notifications (
+ id BIGSERIAL PRIMARY KEY,
+ localpart TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ stream_pos BIGINT NOT NULL,
+ ts_ms BIGINT NOT NULL,
+ highlight BOOLEAN NOT NULL,
+ notification_json TEXT NOT NULL,
+ read BOOLEAN NOT NULL DEFAULT FALSE
+);
+
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
+`
+
+const insertNotificationSQL = "" +
+ "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+
+const deleteNotificationsUpToSQL = "" +
+ "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
+
+const updateNotificationReadSQL = "" +
+ "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
+
+const selectNotificationSQL = "" +
+ "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
+ "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read ORDER BY localpart, id LIMIT $4"
+
+const selectNotificationCountSQL = "" +
+ "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
+ "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read"
+
+const selectRoomNotificationCountsSQL = "" +
+ "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
+ "WHERE localpart = $1 AND room_id = $2 AND NOT read"
+
+func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
+ s := &notificationsStatements{}
+ _, err := db.Exec(notificationSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertStmt, insertNotificationSQL},
+ {&s.deleteUpToStmt, deleteNotificationsUpToSQL},
+ {&s.updateReadStmt, updateNotificationReadSQL},
+ {&s.selectStmt, selectNotificationSQL},
+ {&s.selectCountStmt, selectNotificationCountSQL},
+ {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
+ }.Prepare(db)
+}
+
+// Insert inserts a notification into the database.
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
+ roomID, tsMS := n.RoomID, n.TS
+ nn := *n
+ // Clears out fields that have their own columns to (1) shrink the
+ // data and (2) avoid difficult-to-debug inconsistency bugs.
+ nn.RoomID = ""
+ nn.TS, nn.Read = 0, false
+ bs, err := json.Marshal(nn)
+ if err != nil {
+ return err
+ }
+ _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
+ return err
+}
+
+// DeleteUpTo deletes all previous notifications, up to and including the event.
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+// UpdateRead updates the "read" value for an event.
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
+
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ var maxID int64 = -1
+ var notifs []*api.Notification
+ for rows.Next() {
+ var id int64
+ var roomID string
+ var ts gomatrixserverlib.Timestamp
+ var read bool
+ var jsonStr string
+ err = rows.Scan(
+ &id,
+ &roomID,
+ &ts,
+ &read,
+ &jsonStr)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ var n api.Notification
+ err := json.Unmarshal([]byte(jsonStr), &n)
+ if err != nil {
+ return nil, 0, err
+ }
+ n.RoomID = roomID
+ n.TS = ts
+ n.Read = read
+ notifs = append(notifs, &n)
+
+ if maxID < id {
+ maxID = id
+ }
+ }
+ return notifs, maxID, rows.Err()
+}
+
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
+
+ if err != nil {
+ return 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var count int64
+ if err := rows.Scan(&count); err != nil {
+ return 0, err
+ }
+
+ return count, nil
+ }
+ return 0, rows.Err()
+}
+
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
+
+ if err != nil {
+ return 0, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var total, highlight int64
+ if err := rows.Scan(&total, &highlight); err != nil {
+ return 0, 0, err
+ }
+
+ return total, highlight, nil
+ }
+ return 0, 0, rows.Err()
+}
diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go
new file mode 100644
index 00000000..670dc916
--- /dev/null
+++ b/userapi/storage/postgres/pusher_table.go
@@ -0,0 +1,157 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
+const pushersSchema = `
+CREATE TABLE IF NOT EXISTS userapi_pushers (
+ id BIGSERIAL PRIMARY KEY,
+ -- The Matrix user ID localpart for this pusher
+ localpart TEXT NOT NULL,
+ session_id BIGINT DEFAULT NULL,
+ profile_tag TEXT,
+ kind TEXT NOT NULL,
+ app_id TEXT NOT NULL,
+ app_display_name TEXT NOT NULL,
+ device_display_name TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
+ lang TEXT NOT NULL,
+ data TEXT NOT NULL
+);
+
+-- For faster deleting by app_id, pushkey pair.
+CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
+
+-- For faster retrieving by localpart.
+CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
+
+-- Pushkey must be unique for a given user and app.
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
+`
+
+const insertPusherSQL = "" +
+ "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
+ "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
+
+const selectPushersSQL = "" +
+ "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
+
+const deletePusherSQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
+
+const deletePushersByAppIdAndPushKeySQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
+
+func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) {
+ s := &pushersStatements{}
+ _, err := db.Exec(pushersSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertPusherStmt, insertPusherSQL},
+ {&s.selectPushersStmt, selectPushersSQL},
+ {&s.deletePusherStmt, deletePusherSQL},
+ {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
+ }.Prepare(db)
+}
+
+type pushersStatements struct {
+ insertPusherStmt *sql.Stmt
+ selectPushersStmt *sql.Stmt
+ deletePusherStmt *sql.Stmt
+ deletePushersByAppIdAndPushKeyStmt *sql.Stmt
+}
+
+// insertPusher creates a new pusher.
+// Returns an error if the user already has a pusher with the given pusher pushkey.
+// Returns nil error success.
+func (s *pushersStatements) InsertPusher(
+ ctx context.Context, txn *sql.Tx, session_id int64,
+ pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
+ logrus.Debugf("Created pusher %d", session_id)
+ return err
+}
+
+func (s *pushersStatements) SelectPushers(
+ ctx context.Context, txn *sql.Tx, localpart string,
+) ([]api.Pusher, error) {
+ pushers := []api.Pusher{}
+ rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
+
+ if err != nil {
+ return pushers, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
+
+ for rows.Next() {
+ var pusher api.Pusher
+ var data []byte
+ err = rows.Scan(
+ &pusher.SessionID,
+ &pusher.PushKey,
+ &pusher.PushKeyTS,
+ &pusher.Kind,
+ &pusher.AppID,
+ &pusher.AppDisplayName,
+ &pusher.DeviceDisplayName,
+ &pusher.ProfileTag,
+ &pusher.Language,
+ &data)
+ if err != nil {
+ return pushers, err
+ }
+ err := json.Unmarshal(data, &pusher.Data)
+ if err != nil {
+ return pushers, err
+ }
+ pushers = append(pushers, pusher)
+ }
+
+ logrus.Debugf("Database returned %d pushers", len(pushers))
+ return pushers, rows.Err()
+}
+
+// deletePusher removes a single pusher by pushkey and user localpart.
+func (s *pushersStatements) DeletePusher(
+ ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
+ return err
+}
+
+func (s *pushersStatements) DeletePushers(
+ ctx context.Context, txn *sql.Tx, appid, pushkey string,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey)
+ return err
+}
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index ac5c59b8..c74a999f 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
}
+ pusherTable, err := NewPostgresPusherTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
+ }
+ notificationsTable, err := NewPostgresNotificationTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
+ }
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
@@ -95,6 +103,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
+ Pushers: pusherTable,
+ Notifications: notificationsTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewDummyWriter(),
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index 5f1f9500..a58974b4 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -29,6 +29,7 @@ import (
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
@@ -47,6 +48,8 @@ type Database struct {
KeyBackupVersions tables.KeyBackupVersionTable
Devices tables.DevicesTable
LoginTokens tables.LoginTokenTable
+ Notifications tables.NotificationTable
+ Pushers tables.PusherTable
LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName
BcryptCost int
@@ -160,15 +163,12 @@ func (d *Database) createAccount(
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
- if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
- "global": {
- "content": [],
- "override": [],
- "room": [],
- "sender": [],
- "underride": []
- }
- }`)); err != nil {
+ pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
+ prbs, err := json.Marshal(pushRuleSets)
+ if err != nil {
+ return nil, err
+ }
+ if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err
}
return account, nil
@@ -670,3 +670,94 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.LoginTokens.SelectLoginToken(ctx, token)
}
+
+func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
+ })
+}
+
+func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
+ return err
+ })
+ return
+}
+
+func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
+ return err
+ })
+ return
+}
+
+func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
+}
+
+func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
+ return d.Notifications.SelectCount(ctx, nil, localpart, filter)
+}
+
+func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
+ return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
+}
+
+func (d *Database) UpsertPusher(
+ ctx context.Context, p api.Pusher, localpart string,
+) error {
+ data, err := json.Marshal(p.Data)
+ if err != nil {
+ return err
+ }
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.Pushers.InsertPusher(
+ ctx, txn,
+ p.SessionID,
+ p.PushKey,
+ p.PushKeyTS,
+ p.Kind,
+ p.AppID,
+ p.AppDisplayName,
+ p.DeviceDisplayName,
+ p.ProfileTag,
+ p.Language,
+ string(data),
+ localpart)
+ })
+}
+
+// GetPushers returns the pushers matching the given localpart.
+func (d *Database) GetPushers(
+ ctx context.Context, localpart string,
+) ([]api.Pusher, error) {
+ return d.Pushers.SelectPushers(ctx, nil, localpart)
+}
+
+// RemovePusher deletes one pusher
+// Invoked when `append` is true and `kind` is null in
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
+func (d *Database) RemovePusher(
+ ctx context.Context, appid, pushkey, localpart string,
+) error {
+ return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
+ err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
+ if err == sql.ErrNoRows {
+ return nil
+ }
+ return err
+ })
+}
+
+// RemovePushers deletes all pushers that match given App Id and Push Key pair.
+// Invoked when `append` parameter is false in
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
+func (d *Database) RemovePushers(
+ ctx context.Context, appid, pushkey string,
+) error {
+ return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
+ return d.Pushers.DeletePushers(ctx, txn, appid, pushkey)
+ })
+}
diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go
new file mode 100644
index 00000000..fcfb1aad
--- /dev/null
+++ b/userapi/storage/sqlite3/notifications_table.go
@@ -0,0 +1,219 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+type notificationsStatements struct {
+ insertStmt *sql.Stmt
+ deleteUpToStmt *sql.Stmt
+ updateReadStmt *sql.Stmt
+ selectStmt *sql.Stmt
+ selectCountStmt *sql.Stmt
+ selectRoomCountsStmt *sql.Stmt
+}
+
+const notificationSchema = `
+CREATE TABLE IF NOT EXISTS userapi_notifications (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ localpart TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ stream_pos BIGINT NOT NULL,
+ ts_ms BIGINT NOT NULL,
+ highlight BOOLEAN NOT NULL,
+ notification_json TEXT NOT NULL,
+ read BOOLEAN NOT NULL DEFAULT FALSE
+);
+
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
+`
+
+const insertNotificationSQL = "" +
+ "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+
+const deleteNotificationsUpToSQL = "" +
+ "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
+
+const updateNotificationReadSQL = "" +
+ "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
+
+const selectNotificationSQL = "" +
+ "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
+ "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read ORDER BY localpart, id LIMIT $4"
+
+const selectNotificationCountSQL = "" +
+ "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
+ "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read"
+
+const selectRoomNotificationCountsSQL = "" +
+ "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
+ "WHERE localpart = $1 AND room_id = $2 AND NOT read"
+
+func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
+ s := &notificationsStatements{}
+ _, err := db.Exec(notificationSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertStmt, insertNotificationSQL},
+ {&s.deleteUpToStmt, deleteNotificationsUpToSQL},
+ {&s.updateReadStmt, updateNotificationReadSQL},
+ {&s.selectStmt, selectNotificationSQL},
+ {&s.selectCountStmt, selectNotificationCountSQL},
+ {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
+ }.Prepare(db)
+}
+
+// Insert inserts a notification into the database.
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
+ roomID, tsMS := n.RoomID, n.TS
+ nn := *n
+ // Clears out fields that have their own columns to (1) shrink the
+ // data and (2) avoid difficult-to-debug inconsistency bugs.
+ nn.RoomID = ""
+ nn.TS, nn.Read = 0, false
+ bs, err := json.Marshal(nn)
+ if err != nil {
+ return err
+ }
+ _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
+ return err
+}
+
+// DeleteUpTo deletes all previous notifications, up to and including the event.
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+// UpdateRead updates the "read" value for an event.
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
+
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ var maxID int64 = -1
+ var notifs []*api.Notification
+ for rows.Next() {
+ var id int64
+ var roomID string
+ var ts gomatrixserverlib.Timestamp
+ var read bool
+ var jsonStr string
+ err = rows.Scan(
+ &id,
+ &roomID,
+ &ts,
+ &read,
+ &jsonStr)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ var n api.Notification
+ err := json.Unmarshal([]byte(jsonStr), &n)
+ if err != nil {
+ return nil, 0, err
+ }
+ n.RoomID = roomID
+ n.TS = ts
+ n.Read = read
+ notifs = append(notifs, &n)
+
+ if maxID < id {
+ maxID = id
+ }
+ }
+ return notifs, maxID, rows.Err()
+}
+
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
+
+ if err != nil {
+ return 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var count int64
+ if err := rows.Scan(&count); err != nil {
+ return 0, err
+ }
+
+ return count, nil
+ }
+ return 0, rows.Err()
+}
+
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
+
+ if err != nil {
+ return 0, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var total, highlight int64
+ if err := rows.Scan(&total, &highlight); err != nil {
+ return 0, 0, err
+ }
+
+ return total, highlight, nil
+ }
+ return 0, 0, rows.Err()
+}
diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go
new file mode 100644
index 00000000..e718792e
--- /dev/null
+++ b/userapi/storage/sqlite3/pusher_table.go
@@ -0,0 +1,157 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
+const pushersSchema = `
+CREATE TABLE IF NOT EXISTS userapi_pushers (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The Matrix user ID localpart for this pusher
+ localpart TEXT NOT NULL,
+ session_id BIGINT DEFAULT NULL,
+ profile_tag TEXT,
+ kind TEXT NOT NULL,
+ app_id TEXT NOT NULL,
+ app_display_name TEXT NOT NULL,
+ device_display_name TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
+ lang TEXT NOT NULL,
+ data TEXT NOT NULL
+);
+
+-- For faster deleting by app_id, pushkey pair.
+CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
+
+-- For faster retrieving by localpart.
+CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
+
+-- Pushkey must be unique for a given user and app.
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
+`
+
+const insertPusherSQL = "" +
+ "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
+ "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
+
+const selectPushersSQL = "" +
+ "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
+
+const deletePusherSQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
+
+const deletePushersByAppIdAndPushKeySQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
+
+func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) {
+ s := &pushersStatements{}
+ _, err := db.Exec(pushersSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertPusherStmt, insertPusherSQL},
+ {&s.selectPushersStmt, selectPushersSQL},
+ {&s.deletePusherStmt, deletePusherSQL},
+ {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
+ }.Prepare(db)
+}
+
+type pushersStatements struct {
+ insertPusherStmt *sql.Stmt
+ selectPushersStmt *sql.Stmt
+ deletePusherStmt *sql.Stmt
+ deletePushersByAppIdAndPushKeyStmt *sql.Stmt
+}
+
+// insertPusher creates a new pusher.
+// Returns an error if the user already has a pusher with the given pusher pushkey.
+// Returns nil error success.
+func (s *pushersStatements) InsertPusher(
+ ctx context.Context, txn *sql.Tx, session_id int64,
+ pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+) error {
+ _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
+ logrus.Debugf("Created pusher %d", session_id)
+ return err
+}
+
+func (s *pushersStatements) SelectPushers(
+ ctx context.Context, txn *sql.Tx, localpart string,
+) ([]api.Pusher, error) {
+ pushers := []api.Pusher{}
+ rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
+
+ if err != nil {
+ return pushers, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
+
+ for rows.Next() {
+ var pusher api.Pusher
+ var data []byte
+ err = rows.Scan(
+ &pusher.SessionID,
+ &pusher.PushKey,
+ &pusher.PushKeyTS,
+ &pusher.Kind,
+ &pusher.AppID,
+ &pusher.AppDisplayName,
+ &pusher.DeviceDisplayName,
+ &pusher.ProfileTag,
+ &pusher.Language,
+ &data)
+ if err != nil {
+ return pushers, err
+ }
+ err := json.Unmarshal(data, &pusher.Data)
+ if err != nil {
+ return pushers, err
+ }
+ pushers = append(pushers, pusher)
+ }
+
+ logrus.Debugf("Database returned %d pushers", len(pushers))
+ return pushers, rows.Err()
+}
+
+// deletePusher removes a single pusher by pushkey and user localpart.
+func (s *pushersStatements) DeletePusher(
+ ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
+) error {
+ _, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart)
+ return err
+}
+
+func (s *pushersStatements) DeletePushers(
+ ctx context.Context, txn *sql.Tx, appid, pushkey string,
+) error {
+ _, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey)
+ return err
+}
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index 98c24497..b5bb96c4 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
}
+ pusherTable, err := NewSQLitePusherTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
+ }
+ notificationsTable, err := NewSQLiteNotificationTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
+ }
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
@@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
+ Pushers: pusherTable,
+ Notifications: notificationsTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 12939ced..815e5119 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
)
type AccountDataTable interface {
@@ -93,3 +94,42 @@ type ThreePIDTable interface {
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
}
+
+type PusherTable interface {
+ InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
+ SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
+ DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
+ DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
+}
+
+type NotificationTable interface {
+ Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error
+ DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error)
+ UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error)
+ Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
+ SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
+ SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
+}
+
+type NotificationFilter uint32
+
+const (
+ // HighlightNotifications returns notifications that had a
+ // "highlight" tweak assigned to them from evaluating push rules.
+ HighlightNotifications NotificationFilter = 1 << iota
+
+ // NonHighlightNotifications returns notifications that don't
+ // match HighlightNotifications.
+ NonHighlightNotifications
+
+ // NoNotifications is a filter to exclude all types of
+ // notifications. It's useful as a zero value, but isn't likely to
+ // be used in a call to Notifications.Select*.
+ NoNotifications NotificationFilter = 0
+
+ // AllNotifications is a filter to include all types of
+ // notifications in Notifications.Select*. Note that PostgreSQL
+ // balks if this doesn't fit in INTEGER, even though we use
+ // uint32.
+ AllNotifications NotificationFilter = (1 << 31) - 1
+)
diff --git a/userapi/userapi.go b/userapi/userapi.go
index 4a5793ab..2382e951 100644
--- a/userapi/userapi.go
+++ b/userapi/userapi.go
@@ -18,11 +18,17 @@ import (
"time"
"github.com/gorilla/mux"
+ "github.com/matrix-org/dendrite/internal/pushgateway"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/consumers"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
+ "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/sirupsen/logrus"
)
@@ -36,26 +42,49 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
- accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
+ base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI,
+ appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
+ rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client,
) api.UserInternalAPI {
db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db")
}
- return newInternalAPI(db, cfg, appServices, keyAPI)
-}
+ js := jetstream.Prepare(&cfg.Matrix.JetStream)
-func newInternalAPI(
- db storage.Database,
- cfg *config.UserAPI,
- appServices []config.ApplicationService,
- keyAPI keyapi.KeyInternalAPI,
-) api.UserInternalAPI {
- return &internal.UserInternalAPI{
- DB: db,
- ServerName: cfg.Matrix.ServerName,
- AppServices: appServices,
- KeyAPI: keyAPI,
+ syncProducer := producers.NewSyncAPI(
+ db, js,
+ // TODO: user API should handle syncs for account data. Right now,
+ // it's handled by clientapi, and hence uses its topic. When user
+ // API handles it for all account data, we can remove it from
+ // here.
+ cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
+ cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
+ )
+
+ userAPI := &internal.UserInternalAPI{
+ DB: db,
+ SyncProducer: syncProducer,
+ ServerName: cfg.Matrix.ServerName,
+ AppServices: appServices,
+ KeyAPI: keyAPI,
+ DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
+ }
+
+ readConsumer := consumers.NewOutputReadUpdateConsumer(
+ base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer,
+ )
+ if err := readConsumer.Start(); err != nil {
+ logrus.WithError(err).Panic("failed to start user API read update consumer")
+ }
+
+ eventConsumer := consumers.NewOutputStreamEventConsumer(
+ base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer,
+ )
+ if err := eventConsumer.Start(); err != nil {
+ logrus.WithError(err).Panic("failed to start user API streamed event consumer")
}
+
+ return userAPI
}
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 4214c07f..25319c4b 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/storage"
)
@@ -62,7 +63,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
},
}
- return newInternalAPI(accountDB, cfg, nil, nil), accountDB
+ return &internal.UserInternalAPI{
+ DB: accountDB,
+ ServerName: cfg.Matrix.ServerName,
+ }, accountDB
}
func TestQueryProfile(t *testing.T) {
diff --git a/userapi/util/devices.go b/userapi/util/devices.go
new file mode 100644
index 00000000..cbf3bd28
--- /dev/null
+++ b/userapi/util/devices.go
@@ -0,0 +1,100 @@
+package util
+
+import (
+ "context"
+
+ "github.com/matrix-org/dendrite/internal/pushgateway"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ log "github.com/sirupsen/logrus"
+)
+
+type PusherDevice struct {
+ Device pushgateway.Device
+ Pusher *api.Pusher
+ URL string
+ Format string
+}
+
+// GetPushDevices pushes to the configured devices of a local user.
+func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
+ pushers, err := db.GetPushers(ctx, localpart)
+ if err != nil {
+ return nil, err
+ }
+
+ devices := make([]*PusherDevice, 0, len(pushers))
+ for _, pusher := range pushers {
+ var url, format string
+ data := pusher.Data
+ switch pusher.Kind {
+ case api.EmailKind:
+ url = "mailto:"
+
+ case api.HTTPKind:
+ // TODO: The spec says only event_id_only is supported,
+ // but Sytests assume "" means "full notification".
+ fmtIface := pusher.Data["format"]
+ var ok bool
+ format, ok = fmtIface.(string)
+ if ok && format != "event_id_only" {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id": pusher.AppID,
+ }).Errorf("Only data.format event_id_only or empty is supported")
+ continue
+ }
+
+ urlIface := pusher.Data["url"]
+ url, ok = urlIface.(string)
+ if !ok {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id": pusher.AppID,
+ }).Errorf("No data.url configured for HTTP Pusher")
+ continue
+ }
+ data = mapWithout(data, "url")
+
+ default:
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id": pusher.AppID,
+ "kind": pusher.Kind,
+ }).Errorf("Unhandled pusher kind")
+ continue
+ }
+
+ devices = append(devices, &PusherDevice{
+ Device: pushgateway.Device{
+ AppID: pusher.AppID,
+ Data: data,
+ PushKey: pusher.PushKey,
+ PushKeyTS: pusher.PushKeyTS,
+ Tweaks: tweaks,
+ },
+ Pusher: &pusher,
+ URL: url,
+ Format: format,
+ })
+ }
+
+ return devices, nil
+}
+
+// mapWithout returns a shallow copy of the map, without the given
+// key. Returns nil if the resulting map is empty.
+func mapWithout(m map[string]interface{}, key string) map[string]interface{} {
+ ret := make(map[string]interface{}, len(m))
+ for k, v := range m {
+ // The specification says we do not send "url".
+ if k == key {
+ continue
+ }
+ ret[k] = v
+ }
+ if len(ret) == 0 {
+ return nil
+ }
+ return ret
+}
diff --git a/userapi/util/notify.go b/userapi/util/notify.go
new file mode 100644
index 00000000..ff206bd3
--- /dev/null
+++ b/userapi/util/notify.go
@@ -0,0 +1,76 @@
+package util
+
+import (
+ "context"
+ "strings"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal/pushgateway"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ log "github.com/sirupsen/logrus"
+)
+
+// NotifyUserCountsAsync sends notifications to a local user's
+// notification destinations. Database lookups run synchronously, but
+// a single goroutine is started when talking to the Push
+// gateways. There is no way to know when the background goroutine has
+// finished.
+func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
+ pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
+ if err != nil {
+ return err
+ }
+
+ if len(pusherDevices) == 0 {
+ return nil
+ }
+
+ userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id0": pusherDevices[0].Device.AppID,
+ "pushkey": pusherDevices[0].Device.PushKey,
+ }).Tracef("Notifying HTTP push gateway about notification counts")
+
+ // TODO: think about bounding this to one per user, and what
+ // ordering guarantees we must provide.
+ go func() {
+ // This background processing cannot be tied to a request.
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ // TODO: we could batch all devices with the same URL, but
+ // Sytest requires consumers/roomserver.go to do it
+ // one-by-one, so we do the same here.
+ for _, pusherDevice := range pusherDevices {
+ // TODO: support "email".
+ if !strings.HasPrefix(pusherDevice.URL, "http") {
+ continue
+ }
+
+ req := pushgateway.NotifyRequest{
+ Notification: pushgateway.Notification{
+ Counts: &pushgateway.Counts{
+ Unread: int(userNumUnreadNotifs),
+ },
+ Devices: []*pushgateway.Device{&pusherDevice.Device},
+ },
+ }
+ if err := pgClient.Notify(ctx, pusherDevice.URL, &req, &pushgateway.NotifyResponse{}); err != nil {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id0": pusherDevice.Device.AppID,
+ "pushkey": pusherDevice.Device.PushKey,
+ }).WithError(err).Error("HTTP push gateway request failed")
+ return
+ }
+ }
+ }()
+
+ return nil
+}