aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore6
-rw-r--r--build/docker/config/dendrite.yaml11
-rw-r--r--build/gobind-pinecone/monolith.go2
-rw-r--r--build/gobind-yggdrasil/monolith.go2
-rw-r--r--clientapi/clientapi.go3
-rw-r--r--clientapi/producers/syncapi.go7
-rw-r--r--clientapi/routing/account_data.go12
-rw-r--r--clientapi/routing/notification.go63
-rw-r--r--clientapi/routing/password.go15
-rw-r--r--clientapi/routing/pusher.go114
-rw-r--r--clientapi/routing/pushrules.go386
-rw-r--r--clientapi/routing/room_tagging.go4
-rw-r--r--clientapi/routing/routing.go165
-rw-r--r--cmd/dendrite-demo-libp2p/main.go6
-rw-r--r--cmd/dendrite-demo-pinecone/main.go2
-rw-r--r--cmd/dendrite-demo-yggdrasil/main.go5
-rw-r--r--cmd/dendrite-monolith-server/main.go3
-rw-r--r--cmd/dendrite-polylith-multi/personalities/userapi.go6
-rw-r--r--cmd/dendritejs-pinecone/main.go6
-rw-r--r--cmd/dendritejs/main.go3
-rw-r--r--dendrite-config.yaml86
-rw-r--r--go.mod1
-rw-r--r--internal/eventutil/types.go24
-rw-r--r--internal/pushgateway/client.go66
-rw-r--r--internal/pushgateway/pushgateway.go62
-rw-r--r--internal/pushrules/action.go102
-rw-r--r--internal/pushrules/action_test.go39
-rw-r--r--internal/pushrules/condition.go49
-rw-r--r--internal/pushrules/default.go23
-rw-r--r--internal/pushrules/default_content.go33
-rw-r--r--internal/pushrules/default_override.go165
-rw-r--r--internal/pushrules/default_underride.go119
-rw-r--r--internal/pushrules/evaluate.go165
-rw-r--r--internal/pushrules/evaluate_test.go189
-rw-r--r--internal/pushrules/pushrules.go71
-rw-r--r--internal/pushrules/util.go125
-rw-r--r--internal/pushrules/util_test.go169
-rw-r--r--internal/pushrules/validate.go85
-rw-r--r--internal/pushrules/validate_test.go163
-rw-r--r--internal/sqlutil/sql.go1
-rw-r--r--q.sqlitebin0 -> 12288 bytes
-rwxr-xr-xrun-sytest.sh63
-rw-r--r--setup/base/base.go6
-rw-r--r--setup/config/config_test.go5
-rw-r--r--setup/config/config_userapi.go3
-rw-r--r--setup/jetstream/streams.go18
-rw-r--r--setup/monolith.go4
-rw-r--r--syncapi/consumers/clientapi.go77
-rw-r--r--syncapi/consumers/eduserver_receipts.go71
-rw-r--r--syncapi/consumers/roomserver.go10
-rw-r--r--syncapi/consumers/userapi.go110
-rw-r--r--syncapi/notifier/notifier.go11
-rw-r--r--syncapi/notifier/notifier_test.go2
-rw-r--r--syncapi/producers/userapi_readupdate.go62
-rw-r--r--syncapi/producers/userapi_streamevent.go60
-rw-r--r--syncapi/storage/interface.go8
-rw-r--r--syncapi/storage/postgres/notification_data_table.go108
-rw-r--r--syncapi/storage/postgres/syncserver.go5
-rw-r--r--syncapi/storage/shared/syncserver.go21
-rw-r--r--syncapi/storage/sqlite3/notification_data_table.go108
-rw-r--r--syncapi/storage/sqlite3/output_room_events_table.go6
-rw-r--r--syncapi/storage/sqlite3/syncserver.go5
-rw-r--r--syncapi/storage/tables/interface.go7
-rw-r--r--syncapi/streams/stream_notificationdata.go55
-rw-r--r--syncapi/streams/streams.go34
-rw-r--r--syncapi/sync/requestpool.go9
-rw-r--r--syncapi/syncapi.go24
-rw-r--r--syncapi/types/types.go61
-rw-r--r--syncapi/types/types_test.go8
-rw-r--r--sytest-blacklist6
-rw-r--r--sytest-whitelist79
-rw-r--r--userapi/api/api.go83
-rw-r--r--userapi/api/api_trace.go30
-rw-r--r--userapi/consumers/syncapi_readupdate.go136
-rw-r--r--userapi/consumers/syncapi_streamevent.go588
-rw-r--r--userapi/internal/api.go171
-rw-r--r--userapi/inthttp/client.go61
-rw-r--r--userapi/inthttp/server.go82
-rw-r--r--userapi/producers/syncapi.go104
-rw-r--r--userapi/storage/interface.go13
-rw-r--r--userapi/storage/postgres/notifications_table.go219
-rw-r--r--userapi/storage/postgres/pusher_table.go157
-rw-r--r--userapi/storage/postgres/storage.go10
-rw-r--r--userapi/storage/shared/storage.go109
-rw-r--r--userapi/storage/sqlite3/notifications_table.go219
-rw-r--r--userapi/storage/sqlite3/pusher_table.go157
-rw-r--r--userapi/storage/sqlite3/storage.go10
-rw-r--r--userapi/storage/tables/interface.go40
-rw-r--r--userapi/userapi.go57
-rw-r--r--userapi/userapi_test.go6
-rw-r--r--userapi/util/devices.go100
-rw-r--r--userapi/util/notify.go76
92 files changed, 5839 insertions, 193 deletions
diff --git a/.gitignore b/.gitignore
index 092f4501..2a8c2cf5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -54,7 +54,7 @@ dendrite.yaml
*.db
# Log files
-*.log*
+*.log*
# Generated code
cmd/dendrite-demo-yggdrasil/embed/fs*.go
@@ -62,5 +62,7 @@ cmd/dendrite-demo-yggdrasil/embed/fs*.go
# Test dependencies
test/wasm/node_modules
-media_store/
+# Ignore complement folder when running locally
+complement/
+media_store/
diff --git a/build/docker/config/dendrite.yaml b/build/docker/config/dendrite.yaml
index 6d5ebc9f..ebae5013 100644
--- a/build/docker/config/dendrite.yaml
+++ b/build/docker/config/dendrite.yaml
@@ -318,6 +318,17 @@ user_api:
max_idle_conns: 2
conn_max_lifetime: -1
+# Configuration for the Push Server API.
+push_server:
+ internal_api:
+ listen: http://localhost:7782
+ connect: http://localhost:7782
+ database:
+ connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable
+ max_open_conns: 10
+ max_idle_conns: 2
+ conn_max_lifetime: -1
+
# Configuration for Opentracing.
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
# how this works and how to set it up.
diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go
index aa8cc6e6..5ab90ada 100644
--- a/build/gobind-pinecone/monolith.go
+++ b/build/gobind-pinecone/monolith.go
@@ -312,7 +312,7 @@ func (m *DendriteMonolith) Start() {
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
- m.userAPI = userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
+ m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(m.userAPI)
eduInputAPI := eduserver.NewInternalAPI(
diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go
index 8b9c88f2..3329485a 100644
--- a/build/gobind-yggdrasil/monolith.go
+++ b/build/gobind-yggdrasil/monolith.go
@@ -116,7 +116,7 @@ func (m *DendriteMonolith) Start() {
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
- userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
+ userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI(
diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go
index a65f3b70..91847667 100644
--- a/clientapi/clientapi.go
+++ b/clientapi/clientapi.go
@@ -59,6 +59,7 @@ func AddPublicRoutes(
routing.Setup(
router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI,
accountsDB, userAPI, federation,
- syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg,
+ syncProducer, transactionsCache, fsAPI, keyAPI,
+ extRoomsProvider, mscCfg,
)
}
diff --git a/clientapi/producers/syncapi.go b/clientapi/producers/syncapi.go
index 9b1d6b1a..9ab90391 100644
--- a/clientapi/producers/syncapi.go
+++ b/clientapi/producers/syncapi.go
@@ -30,7 +30,7 @@ type SyncAPIProducer struct {
}
// SendData sends account data to the sync API server
-func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error {
+func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
@@ -38,8 +38,9 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string
m.Header.Set(jetstream.UserID, userID)
data := eventutil.AccountData{
- RoomID: roomID,
- Type: dataType,
+ RoomID: roomID,
+ Type: dataType,
+ ReadMarker: readMarker,
}
var err error
m.Data, err = json.Marshal(data)
diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go
index 03025f1d..d8e98269 100644
--- a/clientapi/routing/account_data.go
+++ b/clientapi/routing/account_data.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
+ "github.com/matrix-org/dendrite/internal/eventutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
@@ -127,7 +128,7 @@ func SaveAccountData(
}
// TODO: user API should do this since it's account data
- if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
+ if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError()
}
@@ -138,11 +139,6 @@ func SaveAccountData(
}
}
-type readMarkerJSON struct {
- FullyRead string `json:"m.fully_read"`
- Read string `json:"m.read"`
-}
-
type fullyReadEvent struct {
EventID string `json:"event_id"`
}
@@ -159,7 +155,7 @@ func SaveReadMarker(
return *resErr
}
- var r readMarkerJSON
+ var r eventutil.ReadMarkerJSON
resErr = httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
@@ -189,7 +185,7 @@ func SaveReadMarker(
return util.ErrorResponse(err)
}
- if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read"); err != nil {
+ if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError()
}
diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go
new file mode 100644
index 00000000..ee715d32
--- /dev/null
+++ b/clientapi/routing/notification.go
@@ -0,0 +1,63 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing
+
+import (
+ "net/http"
+ "strconv"
+
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+// GetNotifications handles /_matrix/client/r0/notifications
+func GetNotifications(
+ req *http.Request, device *userapi.Device,
+ userAPI userapi.UserInternalAPI,
+) util.JSONResponse {
+ var limit int64
+ if limitStr := req.URL.Query().Get("limit"); limitStr != "" {
+ var err error
+ limit, err = strconv.ParseInt(limitStr, 10, 64)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed")
+ return jsonerror.InternalServerError()
+ }
+ }
+
+ var queryRes userapi.QueryNotificationsResponse
+ localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
+ return jsonerror.InternalServerError()
+ }
+ err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
+ Localpart: localpart,
+ From: req.URL.Query().Get("from"),
+ Limit: int(limit),
+ Only: req.URL.Query().Get("only"),
+ }, &queryRes)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
+ return jsonerror.InternalServerError()
+ }
+ util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications))
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: queryRes,
+ }
+}
diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go
index acac60fa..c63412d0 100644
--- a/clientapi/routing/password.go
+++ b/clientapi/routing/password.go
@@ -12,6 +12,7 @@ import (
userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
)
type newPasswordRequest struct {
@@ -37,6 +38,11 @@ func Password(
var r newPasswordRequest
r.LogoutDevices = true
+ logrus.WithFields(logrus.Fields{
+ "sessionId": device.SessionID,
+ "userId": device.UserID,
+ }).Debug("Changing password")
+
// Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
@@ -116,6 +122,15 @@ func Password(
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}
+
+ pushersReq := &api.PerformPusherDeletionRequest{
+ Localpart: localpart,
+ SessionID: device.SessionID,
+ }
+ if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
+ return jsonerror.InternalServerError()
+ }
}
// Return a success code.
diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go
new file mode 100644
index 00000000..9d6bef8b
--- /dev/null
+++ b/clientapi/routing/pusher.go
@@ -0,0 +1,114 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package routing
+
+import (
+ "net/http"
+ "net/url"
+
+ "github.com/matrix-org/dendrite/clientapi/httputil"
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+// GetPushers handles /_matrix/client/r0/pushers
+func GetPushers(
+ req *http.Request, device *userapi.Device,
+ userAPI userapi.UserInternalAPI,
+) util.JSONResponse {
+ var queryRes userapi.QueryPushersResponse
+ localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
+ return jsonerror.InternalServerError()
+ }
+ err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
+ Localpart: localpart,
+ }, &queryRes)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
+ return jsonerror.InternalServerError()
+ }
+ for i := range queryRes.Pushers {
+ queryRes.Pushers[i].SessionID = 0
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: queryRes,
+ }
+}
+
+// SetPusher handles /_matrix/client/r0/pushers/set
+// This endpoint allows the creation, modification and deletion of pushers for this user ID.
+// The behaviour of this endpoint varies depending on the values in the JSON body.
+func SetPusher(
+ req *http.Request, device *userapi.Device,
+ userAPI userapi.UserInternalAPI,
+) util.JSONResponse {
+ localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
+ return jsonerror.InternalServerError()
+ }
+ body := userapi.PerformPusherSetRequest{}
+ if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil {
+ return *resErr
+ }
+ if len(body.AppID) > 64 {
+ return invalidParam("length of app_id must be no more than 64 characters")
+ }
+ if len(body.PushKey) > 512 {
+ return invalidParam("length of pushkey must be no more than 512 bytes")
+ }
+ uInt := body.Data["url"]
+ if uInt != nil {
+ u, ok := uInt.(string)
+ if !ok {
+ return invalidParam("url must be string")
+ }
+ if u != "" {
+ var pushUrl *url.URL
+ pushUrl, err = url.Parse(u)
+ if err != nil {
+ return invalidParam("malformed url passed")
+ }
+ if pushUrl.Scheme != "https" {
+ return invalidParam("only https scheme is allowed")
+ }
+ }
+
+ }
+ body.Localpart = localpart
+ body.SessionID = device.SessionID
+ err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed")
+ return jsonerror.InternalServerError()
+ }
+
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: struct{}{},
+ }
+}
+
+func invalidParam(msg string) util.JSONResponse {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidParam(msg),
+ }
+}
diff --git a/clientapi/routing/pushrules.go b/clientapi/routing/pushrules.go
new file mode 100644
index 00000000..81a33b25
--- /dev/null
+++ b/clientapi/routing/pushrules.go
@@ -0,0 +1,386 @@
+package routing
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "reflect"
+
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ "github.com/matrix-org/dendrite/internal/pushrules"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/util"
+)
+
+func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse {
+ if eerr, ok := err.(*jsonerror.MatrixError); ok {
+ var status int
+ switch eerr.ErrCode {
+ case "M_INVALID_ARGUMENT_VALUE":
+ status = http.StatusBadRequest
+ case "M_NOT_FOUND":
+ status = http.StatusNotFound
+ default:
+ status = http.StatusInternalServerError
+ }
+ return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err)
+ }
+ util.GetLogger(ctx).WithError(err).Errorf(msg, args...)
+ return jsonerror.InternalServerError()
+}
+
+func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
+ ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ if err != nil {
+ return errorResponse(ctx, err, "queryPushRulesJSON failed")
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: ruleSets,
+ }
+}
+
+func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
+ ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ if err != nil {
+ return errorResponse(ctx, err, "queryPushRulesJSON failed")
+ }
+ ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
+ if ruleSet == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: ruleSet,
+ }
+}
+
+func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
+ ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ if err != nil {
+ return errorResponse(ctx, err, "queryPushRules failed")
+ }
+ ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
+ if ruleSet == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
+ }
+ rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
+ if rulesPtr == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: *rulesPtr,
+ }
+}
+
+func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
+ ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ if err != nil {
+ return errorResponse(ctx, err, "queryPushRules failed")
+ }
+ ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
+ if ruleSet == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
+ }
+ rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
+ if rulesPtr == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
+ }
+ i := pushRuleIndexByID(*rulesPtr, ruleID)
+ if i < 0 {
+ return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: (*rulesPtr)[i],
+ }
+}
+
+func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
+ var newRule pushrules.Rule
+ if err := json.NewDecoder(body).Decode(&newRule); err != nil {
+ return errorResponse(ctx, err, "JSON Decode failed")
+ }
+ newRule.RuleID = ruleID
+
+ errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule)
+ if len(errs) > 0 {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
+ }
+
+ ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ if err != nil {
+ return errorResponse(ctx, err, "queryPushRules failed")
+ }
+ ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
+ if ruleSet == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
+ }
+ rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
+ if rulesPtr == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
+ }
+ i := pushRuleIndexByID(*rulesPtr, ruleID)
+ if i >= 0 && afterRuleID == "" && beforeRuleID == "" {
+ // Modify rule at the same index.
+
+ // TODO: The spec does not say what to do in this case, but
+ // this feels reasonable.
+ *((*rulesPtr)[i]) = newRule
+ util.GetLogger(ctx).Infof("Modified existing push rule at %d", i)
+ } else {
+ if i >= 0 {
+ // Delete old rule.
+ *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
+ util.GetLogger(ctx).Infof("Deleted old push rule at %d", i)
+ } else {
+ // SPEC: When creating push rules, they MUST be enabled by default.
+ //
+ // TODO: it's unclear if we must reject disabled rules, or force
+ // the value to true. Sytests fail if we don't force it.
+ newRule.Enabled = true
+ }
+
+ // Add new rule.
+ i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
+ if err != nil {
+ return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
+ }
+
+ *rulesPtr = append((*rulesPtr)[:i], append([]*pushrules.Rule{&newRule}, (*rulesPtr)[i:]...)...)
+ util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
+ }
+
+ if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
+ return errorResponse(ctx, err, "putPushRules failed")
+ }
+
+ return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
+}
+
+func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
+ ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ if err != nil {
+ return errorResponse(ctx, err, "queryPushRules failed")
+ }
+ ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
+ if ruleSet == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
+ }
+ rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
+ if rulesPtr == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
+ }
+ i := pushRuleIndexByID(*rulesPtr, ruleID)
+ if i < 0 {
+ return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
+ }
+
+ *rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
+
+ if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
+ return errorResponse(ctx, err, "putPushRules failed")
+ }
+
+ return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
+}
+
+func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
+ attrGet, err := pushRuleAttrGetter(attr)
+ if err != nil {
+ return errorResponse(ctx, err, "pushRuleAttrGetter failed")
+ }
+ ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ if err != nil {
+ return errorResponse(ctx, err, "queryPushRules failed")
+ }
+ ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
+ if ruleSet == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
+ }
+ rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
+ if rulesPtr == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
+ }
+ i := pushRuleIndexByID(*rulesPtr, ruleID)
+ if i < 0 {
+ return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
+ }
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: map[string]interface{}{
+ attr: attrGet((*rulesPtr)[i]),
+ },
+ }
+}
+
+func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
+ var newPartialRule pushrules.Rule
+ if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.BadJSON(err.Error()),
+ }
+ }
+ if newPartialRule.Actions == nil {
+ // This ensures json.Marshal encodes the empty list as [] rather than null.
+ newPartialRule.Actions = []*pushrules.Action{}
+ }
+
+ attrGet, err := pushRuleAttrGetter(attr)
+ if err != nil {
+ return errorResponse(ctx, err, "pushRuleAttrGetter failed")
+ }
+ attrSet, err := pushRuleAttrSetter(attr)
+ if err != nil {
+ return errorResponse(ctx, err, "pushRuleAttrSetter failed")
+ }
+
+ ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
+ if err != nil {
+ return errorResponse(ctx, err, "queryPushRules failed")
+ }
+ ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
+ if ruleSet == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
+ }
+ rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
+ if rulesPtr == nil {
+ return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
+ }
+ i := pushRuleIndexByID(*rulesPtr, ruleID)
+ if i < 0 {
+ return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
+ }
+
+ if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
+ attrSet((*rulesPtr)[i], &newPartialRule)
+
+ if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
+ return errorResponse(ctx, err, "putPushRules failed")
+ }
+ }
+
+ return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
+}
+
+func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) {
+ var res userapi.QueryPushRulesResponse
+ if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
+ return nil, err
+ }
+ return res.RuleSets, nil
+}
+
+func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error {
+ req := userapi.PerformPushRulesPutRequest{
+ UserID: userID,
+ RuleSets: ruleSets,
+ }
+ var res struct{}
+ if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
+ return err
+ }
+ return nil
+}
+
+func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
+ switch scope {
+ case pushrules.GlobalScope:
+ return &ruleSets.Global
+ default:
+ return nil
+ }
+}
+
+func pushRuleSetKindPointer(ruleSet *pushrules.RuleSet, kind pushrules.Kind) *[]*pushrules.Rule {
+ switch kind {
+ case pushrules.OverrideKind:
+ return &ruleSet.Override
+ case pushrules.ContentKind:
+ return &ruleSet.Content
+ case pushrules.RoomKind:
+ return &ruleSet.Room
+ case pushrules.SenderKind:
+ return &ruleSet.Sender
+ case pushrules.UnderrideKind:
+ return &ruleSet.Underride
+ default:
+ return nil
+ }
+}
+
+func pushRuleIndexByID(rules []*pushrules.Rule, id string) int {
+ for i, rule := range rules {
+ if rule.RuleID == id {
+ return i
+ }
+ }
+ return -1
+}
+
+func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) {
+ switch attr {
+ case "actions":
+ return func(rule *pushrules.Rule) interface{} { return rule.Actions }, nil
+ case "enabled":
+ return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil
+ default:
+ return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
+ }
+}
+
+func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) {
+ switch attr {
+ case "actions":
+ return func(dest, src *pushrules.Rule) { dest.Actions = src.Actions }, nil
+ case "enabled":
+ return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil
+ default:
+ return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
+ }
+}
+
+func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID string) (int, error) {
+ var i int
+
+ if afterID != "" {
+ for ; i < len(rules); i++ {
+ if rules[i].RuleID == afterID {
+ break
+ }
+ }
+ if i == len(rules) {
+ return 0, jsonerror.NotFound("after: rule ID not found")
+ }
+ if rules[i].Default {
+ return 0, jsonerror.NotFound("after: rule ID must not be a default rule")
+ }
+ // We stopped on the "after" match to differentiate
+ // not-found from is-last-entry. Now we move to the earliest
+ // insertion point.
+ i++
+ }
+
+ if beforeID != "" {
+ for ; i < len(rules); i++ {
+ if rules[i].RuleID == beforeID {
+ break
+ }
+ }
+ if i == len(rules) {
+ return 0, jsonerror.NotFound("before: rule ID not found")
+ }
+ if rules[i].Default {
+ return 0, jsonerror.NotFound("before: rule ID must not be a default rule")
+ }
+ }
+
+ // UNSPEC: The spec does not say what to do if no after/before is
+ // given. Sytest fails if it doesn't go first.
+ return i, nil
+}
diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go
index c683cc94..83294b18 100644
--- a/clientapi/routing/room_tagging.go
+++ b/clientapi/routing/room_tagging.go
@@ -98,7 +98,7 @@ func PutTag(
return jsonerror.InternalServerError()
}
- if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
+ if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
@@ -151,7 +151,7 @@ func DeleteTag(
}
// TODO: user API should do this since it's account data
- if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
+ if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index d75f58b8..d22fbd80 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -16,7 +16,6 @@ package routing
import (
"context"
- "encoding/json"
"net/http"
"strings"
@@ -561,25 +560,142 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
+ // Push rules
+
+ v3mux.Handle("/pushrules",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidArgumentValue("missing trailing slash"),
+ }
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
v3mux.Handle("/pushrules/",
- httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
- // TODO: Implement push rules API
- res := json.RawMessage(`{
- "global": {
- "content": [],
- "override": [],
- "room": [],
- "sender": [],
- "underride": []
- }
- }`)
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return GetAllPushRules(req.Context(), device, userAPI)
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/pushrules/",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
- Code: http.StatusOK,
- JSON: &res,
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"),
+ }
+ }),
+ ).Methods(http.MethodPut)
+
+ v3mux.Handle("/pushrules/{scope}/",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ return GetPushRulesByScope(req.Context(), vars["scope"], device, userAPI)
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/pushrules/{scope}",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"),
+ }
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/pushrules/{scope:[^/]+/?}",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"),
+ }
+ }),
+ ).Methods(http.MethodPut)
+
+ v3mux.Handle("/pushrules/{scope}/{kind}/",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ return GetPushRulesByKind(req.Context(), vars["scope"], vars["kind"], device, userAPI)
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/pushrules/{scope}/{kind}",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"),
+ }
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"),
+ }
+ }),
+ ).Methods(http.MethodPut)
+
+ v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
}
+ return GetPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
+ v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ if r := rateLimits.Limit(req); r != nil {
+ return *r
+ }
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ query := req.URL.Query()
+ return PutPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], query.Get("after"), query.Get("before"), req.Body, device, userAPI)
+ }),
+ ).Methods(http.MethodPut)
+
+ v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ return DeletePushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
+ }),
+ ).Methods(http.MethodDelete)
+
+ v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ return GetPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], device, userAPI)
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
+ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ return PutPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], req.Body, device, userAPI)
+ }),
+ ).Methods(http.MethodPut)
+
// Element user settings
v3mux.Handle("/profile/{userID}",
@@ -885,6 +1001,27 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
+ v3mux.Handle("/notifications",
+ httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return GetNotifications(req, device, userAPI)
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/pushers",
+ httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return GetPushers(req, device, userAPI)
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ v3mux.Handle("/pushers/set",
+ httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ if r := rateLimits.Limit(req); r != nil {
+ return *r
+ }
+ return SetPusher(req, device, userAPI)
+ }),
+ ).Methods(http.MethodPost, http.MethodOptions)
+
// Stub implementations for sytest
v3mux.Handle("/events",
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go
index 78536901..8ce64191 100644
--- a/cmd/dendrite-demo-libp2p/main.go
+++ b/cmd/dendrite-demo-libp2p/main.go
@@ -144,12 +144,14 @@ func main() {
accountDB := base.Base.CreateAccountsDB()
federation := createFederationClient(base)
keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation)
- userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
- keyAPI.SetUserAPI(userAPI)
rsAPI := roomserver.NewInternalAPI(
&base.Base,
)
+
+ userAPI := userapi.NewInternalAPI(&base.Base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.Base.PushGatewayHTTPClient())
+ keyAPI.SetUserAPI(userAPI)
+
eduInputAPI := eduserver.NewInternalAPI(
&base.Base, cache.New(), userAPI,
)
diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go
index 5810a7f1..45f18698 100644
--- a/cmd/dendrite-demo-pinecone/main.go
+++ b/cmd/dendrite-demo-pinecone/main.go
@@ -187,7 +187,7 @@ func main() {
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
- userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
+ userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI(
diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go
index d16f0e9e..b7e30ba2 100644
--- a/cmd/dendrite-demo-yggdrasil/main.go
+++ b/cmd/dendrite-demo-yggdrasil/main.go
@@ -111,14 +111,15 @@ func main() {
keyRing := serverKeyAPI.KeyRing()
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
- userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
- keyAPI.SetUserAPI(userAPI)
rsComponent := roomserver.NewInternalAPI(
base,
)
rsAPI := rsComponent
+ userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
+ keyAPI.SetUserAPI(userAPI)
+
eduInputAPI := eduserver.NewInternalAPI(
base, cache.New(), userAPI,
)
diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go
index bb268520..3b952504 100644
--- a/cmd/dendrite-monolith-server/main.go
+++ b/cmd/dendrite-monolith-server/main.go
@@ -106,7 +106,8 @@ func main() {
keyAPI = base.KeyServerHTTPClient()
}
- userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI)
+ pgClient := base.PushGatewayHTTPClient()
+ userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient)
userAPI := userImpl
if base.UseHTTPAPIs {
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
diff --git a/cmd/dendrite-polylith-multi/personalities/userapi.go b/cmd/dendrite-polylith-multi/personalities/userapi.go
index f147cda1..f1fa379c 100644
--- a/cmd/dendrite-polylith-multi/personalities/userapi.go
+++ b/cmd/dendrite-polylith-multi/personalities/userapi.go
@@ -23,7 +23,11 @@ import (
func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
accountDB := base.CreateAccountsDB()
- userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient())
+ userAPI := userapi.NewInternalAPI(
+ base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices,
+ base.KeyServerHTTPClient(), base.RoomserverHTTPClient(),
+ base.PushGatewayHTTPClient(),
+ )
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)
diff --git a/cmd/dendritejs-pinecone/main.go b/cmd/dendritejs-pinecone/main.go
index 664f644f..407081f5 100644
--- a/cmd/dendritejs-pinecone/main.go
+++ b/cmd/dendritejs-pinecone/main.go
@@ -184,13 +184,15 @@ func startup() {
accountDB := base.CreateAccountsDB()
federation := conn.CreateFederationClient(base, pSessions)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
- userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
- keyAPI.SetUserAPI(userAPI)
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
rsAPI := roomserver.NewInternalAPI(base)
+
+ userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
+ keyAPI.SetUserAPI(userAPI)
+
eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI)
asQuery := appservice.NewInternalAPI(
base, userAPI, rsAPI,
diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go
index 0ea41b4c..37cbb12d 100644
--- a/cmd/dendritejs/main.go
+++ b/cmd/dendritejs/main.go
@@ -212,6 +212,8 @@ func main() {
rsAPI.SetFederationAPI(fedSenderAPI, keyRing)
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)
+ psAPI := pushserver.NewInternalAPI(base)
+
monolith := setup.Monolith{
Config: base.Cfg,
AccountDB: accountDB,
@@ -225,6 +227,7 @@ func main() {
RoomserverAPI: rsAPI,
UserAPI: userAPI,
KeyAPI: keyAPI,
+ PushserverAPI: psAPI,
//ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: p2pPublicRoomProvider,
}
diff --git a/dendrite-config.yaml b/dendrite-config.yaml
index 533b5c95..0236851c 100644
--- a/dendrite-config.yaml
+++ b/dendrite-config.yaml
@@ -6,7 +6,7 @@
#
# At a minimum, to get started, you will need to update the settings in the
# "global" section for your deployment, and you will need to check that the
-# database "connection_string" line in each component section is correct.
+# database "connection_string" line in each component section is correct.
#
# Each component with a "database" section can accept the following formats
# for "connection_string":
@@ -21,13 +21,13 @@
# small number of users and likely will perform worse still with a higher volume
# of users.
#
-# The "max_open_conns" and "max_idle_conns" settings configure the maximum
+# The "max_open_conns" and "max_idle_conns" settings configure the maximum
# number of open/idle database connections. The value 0 will use the database
# engine default, and a negative value will use unlimited connections. The
# "conn_max_lifetime" option controls the maximum length of time a database
# connection can be idle in seconds - a negative value is unlimited.
-# The version of the configuration file.
+# The version of the configuration file.
version: 2
# Global Matrix configuration. This configuration applies to all components.
@@ -61,8 +61,8 @@ global:
# Lists of domains that the server will trust as identity servers to verify third
# party identifiers such as phone numbers and email addresses.
trusted_third_party_id_servers:
- - matrix.org
- - vector.im
+ - matrix.org
+ - vector.im
# Disables federation. Dendrite will not be able to make any outbound HTTP requests
# to other servers and the federation API will not be exposed.
@@ -87,14 +87,14 @@ global:
# in monolith mode. It is required to specify the address of at least one
# NATS Server node if running in polylith mode.
addresses:
- # - localhost:4222
+ # - localhost:4222
# Keep all NATS streams in memory, rather than persisting it to the storage
# path below. This option is present primarily for integration testing and
# should not be used on a real world Dendrite deployment.
in_memory: false
- # Persistent directory to store JetStream streams in. This directory
+ # Persistent directory to store JetStream streams in. This directory
# should be preserved across Dendrite restarts.
storage_path: ./
@@ -126,7 +126,7 @@ global:
# Configuration for the Appservice API.
app_service_api:
internal_api:
- listen: http://localhost:7777 # Only used in polylith deployments
+ listen: http://localhost:7777 # Only used in polylith deployments
connect: http://localhost:7777 # Only used in polylith deployments
database:
connection_string: file:appservice.db
@@ -145,7 +145,7 @@ app_service_api:
# Configuration for the Client API.
client_api:
internal_api:
- listen: http://localhost:7771 # Only used in polylith deployments
+ listen: http://localhost:7771 # Only used in polylith deployments
connect: http://localhost:7771 # Only used in polylith deployments
external_api:
listen: http://[::]:8071
@@ -165,13 +165,13 @@ client_api:
# Whether to require reCAPTCHA for registration.
enable_registration_captcha: false
- # Settings for ReCAPTCHA.
+ # Settings for ReCAPTCHA.
recaptcha_public_key: ""
recaptcha_private_key: ""
recaptcha_bypass_secret: ""
recaptcha_siteverify_api: ""
- # TURN server information that this homeserver should send to clients.
+ # TURN server information that this homeserver should send to clients.
turn:
turn_user_lifetime: ""
turn_uris: []
@@ -180,7 +180,7 @@ client_api:
turn_password: ""
# Settings for rate-limited endpoints. Rate limiting will kick in after the
- # threshold number of "slots" have been taken by requests from a specific
+ # threshold number of "slots" have been taken by requests from a specific
# host. Each "slot" will be released after the cooloff time in milliseconds.
rate_limiting:
enabled: true
@@ -190,13 +190,13 @@ client_api:
# Configuration for the EDU server.
edu_server:
internal_api:
- listen: http://localhost:7778 # Only used in polylith deployments
+ listen: http://localhost:7778 # Only used in polylith deployments
connect: http://localhost:7778 # Only used in polylith deployments
# Configuration for the Federation API.
federation_api:
internal_api:
- listen: http://localhost:7772 # Only used in polylith deployments
+ listen: http://localhost:7772 # Only used in polylith deployments
connect: http://localhost:7772 # Only used in polylith deployments
external_api:
listen: http://[::]:8072
@@ -224,12 +224,12 @@ federation_api:
# be required to satisfy key requests for servers that are no longer online when
# joining some rooms.
key_perspectives:
- - server_name: matrix.org
- keys:
- - key_id: ed25519:auto
- public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw
- - key_id: ed25519:a_RXGa
- public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ
+ - server_name: matrix.org
+ keys:
+ - key_id: ed25519:auto
+ public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw
+ - key_id: ed25519:a_RXGa
+ public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ
# This option will control whether Dendrite will prefer to look up keys directly
# or whether it should try perspective servers first, using direct fetches as a
@@ -239,7 +239,7 @@ federation_api:
# Configuration for the Key Server (for end-to-end encryption).
key_server:
internal_api:
- listen: http://localhost:7779 # Only used in polylith deployments
+ listen: http://localhost:7779 # Only used in polylith deployments
connect: http://localhost:7779 # Only used in polylith deployments
database:
connection_string: file:keyserver.db
@@ -250,7 +250,7 @@ key_server:
# Configuration for the Media API.
media_api:
internal_api:
- listen: http://localhost:7774 # Only used in polylith deployments
+ listen: http://localhost:7774 # Only used in polylith deployments
connect: http://localhost:7774 # Only used in polylith deployments
external_api:
listen: http://[::]:8074
@@ -276,15 +276,15 @@ media_api:
# A list of thumbnail sizes to be generated for media content.
thumbnail_sizes:
- - width: 32
- height: 32
- method: crop
- - width: 96
- height: 96
- method: crop
- - width: 640
- height: 480
- method: scale
+ - width: 32
+ height: 32
+ method: crop
+ - width: 96
+ height: 96
+ method: crop
+ - width: 640
+ height: 480
+ method: scale
# Configuration for experimental MSC's
mscs:
@@ -302,7 +302,7 @@ mscs:
# Configuration for the Room Server.
room_server:
internal_api:
- listen: http://localhost:7770 # Only used in polylith deployments
+ listen: http://localhost:7770 # Only used in polylith deployments
connect: http://localhost:7770 # Only used in polylith deployments
database:
connection_string: file:roomserver.db
@@ -313,7 +313,7 @@ room_server:
# Configuration for the Sync API.
sync_api:
internal_api:
- listen: http://localhost:7773 # Only used in polylith deployments
+ listen: http://localhost:7773 # Only used in polylith deployments
connect: http://localhost:7773 # Only used in polylith deployments
external_api:
listen: http://[::]:8073
@@ -338,16 +338,16 @@ user_api:
# This value can be low if performing tests or on embedded Dendrite instances (e.g WASM builds)
# bcrypt_cost: 10
internal_api:
- listen: http://localhost:7781 # Only used in polylith deployments
+ listen: http://localhost:7781 # Only used in polylith deployments
connect: http://localhost:7781 # Only used in polylith deployments
account_database:
connection_string: file:userapi_accounts.db
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
- # The length of time that a token issued for a relying party from
+ # The length of time that a token issued for a relying party from
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
- # is considered to be valid in milliseconds.
+ # is considered to be valid in milliseconds.
# The default lifetime is 3600000ms (60 minutes).
# openid_token_lifetime_ms: 3600000
@@ -369,10 +369,10 @@ tracing:
# Logging configuration
logging:
-- type: std
- level: info
-- type: file
- # The logging level, must be one of debug, info, warn, error, fatal, panic.
- level: info
- params:
- path: ./logs
+ - type: std
+ level: info
+ - type: file
+ # The logging level, must be one of debug, info, warn, error, fatal, panic.
+ level: info
+ params:
+ path: ./logs
diff --git a/go.mod b/go.mod
index dbcae5d5..525950da 100644
--- a/go.mod
+++ b/go.mod
@@ -18,6 +18,7 @@ require (
github.com/frankban/quicktest v1.14.0 // indirect
github.com/getsentry/sentry-go v0.12.0
github.com/gologme/log v1.3.0
+ github.com/google/go-cmp v0.5.6
github.com/google/uuid v1.2.0
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2
diff --git a/internal/eventutil/types.go b/internal/eventutil/types.go
index 6d119ce6..17861d6c 100644
--- a/internal/eventutil/types.go
+++ b/internal/eventutil/types.go
@@ -26,8 +26,30 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID")
// AccountData represents account data sent from the client API server to the
// sync API server
type AccountData struct {
+ RoomID string `json:"room_id"`
+ Type string `json:"type"`
+ ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional
+}
+
+type ReadMarkerJSON struct {
+ FullyRead string `json:"m.fully_read"`
+ Read string `json:"m.read"`
+}
+
+// NotificationData contains statistics about notifications, sent from
+// the Push Server to the Sync API server.
+type NotificationData struct {
+ // RoomID identifies the scope of the statistics, together with
+ // MXID (which is encoded in the Kafka key).
RoomID string `json:"room_id"`
- Type string `json:"type"`
+
+ // HighlightCount is the number of unread notifications with the
+ // highlight tweak.
+ UnreadHighlightCount int `json:"unread_highlight_count"`
+
+ // UnreadNotificationCount is the total number of unread
+ // notifications.
+ UnreadNotificationCount int `json:"unread_notification_count"`
}
// ProfileResponse is a struct containing all known user profile data
diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go
new file mode 100644
index 00000000..49907cee
--- /dev/null
+++ b/internal/pushgateway/client.go
@@ -0,0 +1,66 @@
+package pushgateway
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "time"
+
+ "github.com/opentracing/opentracing-go"
+)
+
+type httpClient struct {
+ hc *http.Client
+}
+
+// NewHTTPClient creates a new Push Gateway client.
+func NewHTTPClient(disableTLSValidation bool) Client {
+ hc := &http.Client{
+ Timeout: 30 * time.Second,
+ Transport: &http.Transport{
+ DisableKeepAlives: true,
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: disableTLSValidation,
+ },
+ },
+ }
+ return &httpClient{hc: hc}
+}
+
+func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "Notify")
+ defer span.Finish()
+
+ body, err := json.Marshal(req)
+ if err != nil {
+ return err
+ }
+ hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return err
+ }
+ hreq.Header.Set("Content-Type", "application/json")
+
+ hresp, err := h.hc.Do(hreq)
+ if err != nil {
+ return err
+ }
+
+ //nolint:errcheck
+ defer hresp.Body.Close()
+
+ if hresp.StatusCode == http.StatusOK {
+ return json.NewDecoder(hresp.Body).Decode(resp)
+ }
+
+ var errorBody struct {
+ Message string `json:"message"`
+ }
+ if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil {
+ return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message)
+ }
+ return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url)
+}
diff --git a/internal/pushgateway/pushgateway.go b/internal/pushgateway/pushgateway.go
new file mode 100644
index 00000000..88c326eb
--- /dev/null
+++ b/internal/pushgateway/pushgateway.go
@@ -0,0 +1,62 @@
+package pushgateway
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// A Client is how interactions with a Push Gateway is done.
+type Client interface {
+ // Notify sends a notification to the gateway at the given URL.
+ Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error
+}
+
+type NotifyRequest struct {
+ Notification Notification `json:"notification"` // Required
+}
+
+type NotifyResponse struct {
+ // Rejected is the list of device push keys that were rejected
+ // during the push. The caller should remove the push keys so they
+ // are not used again.
+ Rejected []string `json:"rejected"` // Required
+}
+
+type Notification struct {
+ Content json.RawMessage `json:"content,omitempty"`
+ Counts *Counts `json:"counts,omitempty"`
+ Devices []*Device `json:"devices"` // Required
+ EventID string `json:"event_id,omitempty"`
+ ID string `json:"id,omitempty"` // Deprecated name for EventID.
+ Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest.
+ Prio Prio `json:"prio,omitempty"`
+ RoomAlias string `json:"room_alias,omitempty"`
+ RoomID string `json:"room_id,omitempty"`
+ RoomName string `json:"room_name,omitempty"`
+ Sender string `json:"sender,omitempty"`
+ SenderDisplayName string `json:"sender_display_name,omitempty"`
+ Type string `json:"type,omitempty"`
+ UserIsTarget bool `json:"user_is_target,omitempty"`
+}
+
+type Counts struct {
+ MissedCalls int `json:"missed_calls,omitempty"`
+ Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it.
+}
+
+type Device struct {
+ AppID string `json:"app_id"` // Required
+ Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
+ PushKey string `json:"pushkey"` // Required
+ PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
+ Tweaks map[string]interface{} `json:"tweaks,omitempty"`
+}
+
+type Prio string
+
+const (
+ HighPrio Prio = "high"
+ LowPrio Prio = "low"
+)
diff --git a/internal/pushrules/action.go b/internal/pushrules/action.go
new file mode 100644
index 00000000..c7b8cec8
--- /dev/null
+++ b/internal/pushrules/action.go
@@ -0,0 +1,102 @@
+package pushrules
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+)
+
+// An Action is (part of) an outcome of a rule. There are
+// (unofficially) terminal actions, and modifier actions.
+type Action struct {
+ // Kind is the type of action. Has custom encoding in JSON.
+ Kind ActionKind `json:"-"`
+
+ // Tweak is the property to tweak. Has custom encoding in JSON.
+ Tweak TweakKey `json:"-"`
+
+ // Value is some value interpreted according to Kind and Tweak.
+ Value interface{} `json:"value"`
+}
+
+func (a *Action) MarshalJSON() ([]byte, error) {
+ if a.Tweak == UnknownTweak && a.Value == nil {
+ return json.Marshal(a.Kind)
+ }
+
+ if a.Kind != SetTweakAction {
+ return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind)
+ }
+
+ m := map[string]interface{}{
+ string(a.Kind): a.Tweak,
+ }
+ if a.Value != nil {
+ m["value"] = a.Value
+ }
+
+ return json.Marshal(m)
+}
+
+func (a *Action) UnmarshalJSON(bs []byte) error {
+ if bytes.HasPrefix(bs, []byte("\"")) {
+ return json.Unmarshal(bs, &a.Kind)
+ }
+
+ var raw struct {
+ SetTweak TweakKey `json:"set_tweak"`
+ Value interface{} `json:"value"`
+ }
+ if err := json.Unmarshal(bs, &raw); err != nil {
+ return err
+ }
+ if raw.SetTweak == UnknownTweak {
+ return fmt.Errorf("got unknown action JSON: %s", string(bs))
+ }
+ a.Kind = SetTweakAction
+ a.Tweak = raw.SetTweak
+ if raw.Value != nil {
+ a.Value = raw.Value
+ }
+
+ return nil
+}
+
+// ActionKind is the primary discriminator for actions.
+type ActionKind string
+
+const (
+ UnknownAction ActionKind = ""
+
+ // NotifyAction indicates the clients should show a notification.
+ NotifyAction ActionKind = "notify"
+
+ // DontNotifyAction indicates the clients should not show a notification.
+ DontNotifyAction ActionKind = "dont_notify"
+
+ // CoalesceAction tells the clients to show a notification, and
+ // tells both servers and clients that multiple events can be
+ // coalesced into a single notification. The behaviour is
+ // implementation-specific.
+ CoalesceAction ActionKind = "coalesce"
+
+ // SetTweakAction uses the Tweak and Value fields to add a
+ // tweak. Multiple SetTweakAction can be provided in a rule,
+ // combined with NotifyAction or CoalesceAction.
+ SetTweakAction ActionKind = "set_tweak"
+)
+
+// A TweakKey describes a property to be modified/tweaked for events
+// that match the rule.
+type TweakKey string
+
+const (
+ UnknownTweak TweakKey = ""
+
+ // SoundTweak describes which sound to play. Using "default" means
+ // "enable sound".
+ SoundTweak TweakKey = "sound"
+
+ // HighlightTweak asks the clients to highlight the conversation.
+ HighlightTweak TweakKey = "highlight"
+)
diff --git a/internal/pushrules/action_test.go b/internal/pushrules/action_test.go
new file mode 100644
index 00000000..72db9c99
--- /dev/null
+++ b/internal/pushrules/action_test.go
@@ -0,0 +1,39 @@
+package pushrules
+
+import (
+ "encoding/json"
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestActionJSON(t *testing.T) {
+ tsts := []struct {
+ Want Action
+ }{
+ {Action{Kind: NotifyAction}},
+ {Action{Kind: DontNotifyAction}},
+ {Action{Kind: CoalesceAction}},
+ {Action{Kind: SetTweakAction}},
+
+ {Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}},
+ {Action{Kind: SetTweakAction, Tweak: HighlightTweak}},
+ {Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}},
+ }
+ for _, tst := range tsts {
+ t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) {
+ bs, err := json.Marshal(&tst.Want)
+ if err != nil {
+ t.Fatalf("Marshal failed: %v", err)
+ }
+ var got Action
+ if err := json.Unmarshal(bs, &got); err != nil {
+ t.Fatalf("Unmarshal failed: %v", err)
+ }
+ if diff := cmp.Diff(tst.Want, got); diff != "" {
+ t.Errorf("+got -want:\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go
new file mode 100644
index 00000000..2d9773c0
--- /dev/null
+++ b/internal/pushrules/condition.go
@@ -0,0 +1,49 @@
+package pushrules
+
+// A Condition dictates extra conditions for a matching rules. See
+// ConditionKind.
+type Condition struct {
+ // Kind is the primary discriminator for the condition
+ // type. Required.
+ Kind ConditionKind `json:"kind"`
+
+ // Key indicates the dot-separated path of Event fields to
+ // match. Required for EventMatchCondition and
+ // SenderNotificationPermissionCondition.
+ Key string `json:"key,omitempty"`
+
+ // Pattern indicates the value pattern that must match. Required
+ // for EventMatchCondition.
+ Pattern string `json:"pattern,omitempty"`
+
+ // Is indicates the condition that must be fulfilled. Required for
+ // RoomMemberCountCondition.
+ Is string `json:"is,omitempty"`
+}
+
+// ConditionKind represents a kind of condition.
+//
+// SPEC: Unrecognised conditions MUST NOT match any events,
+// effectively making the push rule disabled.
+type ConditionKind string
+
+const (
+ UnknownCondition ConditionKind = ""
+
+ // EventMatchCondition indicates the condition looks for a key
+ // path and matches a pattern. How paths that don't reference a
+ // simple value match against rules is implementation-specific.
+ EventMatchCondition ConditionKind = "event_match"
+
+ // ContainsDisplayNameCondition indicates the current user's
+ // display name must be found in the content body.
+ ContainsDisplayNameCondition ConditionKind = "contains_display_name"
+
+ // RoomMemberCountCondition matches a simple arithmetic comparison
+ // against the total number of members in a room.
+ RoomMemberCountCondition ConditionKind = "room_member_count"
+
+ // SenderNotificationPermissionCondition compares power level for
+ // the sender in the event's room.
+ SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission"
+)
diff --git a/internal/pushrules/default.go b/internal/pushrules/default.go
new file mode 100644
index 00000000..99698551
--- /dev/null
+++ b/internal/pushrules/default.go
@@ -0,0 +1,23 @@
+package pushrules
+
+import (
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// DefaultAccountRuleSets is the complete set of default push rules
+// for an account.
+func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets {
+ return &AccountRuleSets{
+ Global: *DefaultGlobalRuleSet(localpart, serverName),
+ }
+}
+
+// DefaultGlobalRuleSet returns the default ruleset for a given (fully
+// qualified) MXID.
+func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet {
+ return &RuleSet{
+ Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)),
+ Content: defaultContentRules(localpart),
+ Underride: defaultUnderrideRules,
+ }
+}
diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go
new file mode 100644
index 00000000..158afd18
--- /dev/null
+++ b/internal/pushrules/default_content.go
@@ -0,0 +1,33 @@
+package pushrules
+
+func defaultContentRules(localpart string) []*Rule {
+ return []*Rule{
+ mRuleContainsUserNameDefinition(localpart),
+ }
+}
+
+const (
+ MRuleContainsUserName = ".m.rule.contains_user_name"
+)
+
+func mRuleContainsUserNameDefinition(localpart string) *Rule {
+ return &Rule{
+ RuleID: MRuleContainsUserName,
+ Default: true,
+ Enabled: true,
+ Pattern: localpart,
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: SoundTweak,
+ Value: "default",
+ },
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: true,
+ },
+ },
+ }
+}
diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go
new file mode 100644
index 00000000..6f66fd66
--- /dev/null
+++ b/internal/pushrules/default_override.go
@@ -0,0 +1,165 @@
+package pushrules
+
+func defaultOverrideRules(userID string) []*Rule {
+ return []*Rule{
+ &mRuleMasterDefinition,
+ &mRuleSuppressNoticesDefinition,
+ mRuleInviteForMeDefinition(userID),
+ &mRuleMemberEventDefinition,
+ &mRuleContainsDisplayNameDefinition,
+ &mRuleTombstoneDefinition,
+ &mRuleRoomNotifDefinition,
+ }
+}
+
+const (
+ MRuleMaster = ".m.rule.master"
+ MRuleSuppressNotices = ".m.rule.suppress_notices"
+ MRuleInviteForMe = ".m.rule.invite_for_me"
+ MRuleMemberEvent = ".m.rule.member_event"
+ MRuleContainsDisplayName = ".m.rule.contains_display_name"
+ MRuleTombstone = ".m.rule.tombstone"
+ MRuleRoomNotif = ".m.rule.roomnotif"
+)
+
+var (
+ mRuleMasterDefinition = Rule{
+ RuleID: MRuleMaster,
+ Default: true,
+ Enabled: false,
+ Conditions: []*Condition{},
+ Actions: []*Action{{Kind: DontNotifyAction}},
+ }
+ mRuleSuppressNoticesDefinition = Rule{
+ RuleID: MRuleSuppressNotices,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "content.msgtype",
+ Pattern: "m.notice",
+ },
+ },
+ Actions: []*Action{{Kind: DontNotifyAction}},
+ }
+ mRuleMemberEventDefinition = Rule{
+ RuleID: MRuleMemberEvent,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.member",
+ },
+ },
+ Actions: []*Action{{Kind: DontNotifyAction}},
+ }
+ mRuleContainsDisplayNameDefinition = Rule{
+ RuleID: MRuleContainsDisplayName,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}},
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: SoundTweak,
+ Value: "default",
+ },
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: true,
+ },
+ },
+ }
+ mRuleTombstoneDefinition = Rule{
+ RuleID: MRuleTombstone,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.tombstone",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "state_key",
+ Pattern: "",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+ mRuleRoomNotifDefinition = Rule{
+ RuleID: MRuleRoomNotif,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "content.body",
+ Pattern: "@room",
+ },
+ {
+ Kind: SenderNotificationPermissionCondition,
+ Key: "room",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+)
+
+func mRuleInviteForMeDefinition(userID string) *Rule {
+ return &Rule{
+ RuleID: MRuleInviteForMe,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.member",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "content.membership",
+ Pattern: "invite",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "state_key",
+ Pattern: userID,
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: SoundTweak,
+ Value: "default",
+ },
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+}
diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go
new file mode 100644
index 00000000..de72bd52
--- /dev/null
+++ b/internal/pushrules/default_underride.go
@@ -0,0 +1,119 @@
+package pushrules
+
+const (
+ MRuleCall = ".m.rule.call"
+ MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one"
+ MRuleRoomOneToOne = ".m.rule.room_one_to_one"
+ MRuleMessage = ".m.rule.message"
+ MRuleEncrypted = ".m.rule.encrypted"
+)
+
+var defaultUnderrideRules = []*Rule{
+ &mRuleCallDefinition,
+ &mRuleEncryptedRoomOneToOneDefinition,
+ &mRuleRoomOneToOneDefinition,
+ &mRuleMessageDefinition,
+ &mRuleEncryptedDefinition,
+}
+
+var (
+ mRuleCallDefinition = Rule{
+ RuleID: MRuleCall,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.call.invite",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: SoundTweak,
+ Value: "ring",
+ },
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+ mRuleEncryptedRoomOneToOneDefinition = Rule{
+ RuleID: MRuleEncryptedRoomOneToOne,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: RoomMemberCountCondition,
+ Is: "2",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.encrypted",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+ mRuleRoomOneToOneDefinition = Rule{
+ RuleID: MRuleRoomOneToOne,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: RoomMemberCountCondition,
+ Is: "2",
+ },
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.message",
+ },
+ },
+ Actions: []*Action{
+ {Kind: NotifyAction},
+ {
+ Kind: SetTweakAction,
+ Tweak: HighlightTweak,
+ Value: false,
+ },
+ },
+ }
+ mRuleMessageDefinition = Rule{
+ RuleID: MRuleMessage,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.message",
+ },
+ },
+ Actions: []*Action{{Kind: NotifyAction}},
+ }
+ mRuleEncryptedDefinition = Rule{
+ RuleID: MRuleEncrypted,
+ Default: true,
+ Enabled: true,
+ Conditions: []*Condition{
+ {
+ Kind: EventMatchCondition,
+ Key: "type",
+ Pattern: "m.room.encrypted",
+ },
+ },
+ Actions: []*Action{{Kind: NotifyAction}},
+ }
+)
diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go
new file mode 100644
index 00000000..df22cb04
--- /dev/null
+++ b/internal/pushrules/evaluate.go
@@ -0,0 +1,165 @@
+package pushrules
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+// A RuleSetEvaluator encapsulates context to evaluate an event
+// against a rule set.
+type RuleSetEvaluator struct {
+ ec EvaluationContext
+ ruleSet []kindAndRules
+}
+
+// An EvaluationContext gives a RuleSetEvaluator access to the
+// environment, for rules that require that.
+type EvaluationContext interface {
+ // UserDisplayName returns the current user's display name.
+ UserDisplayName() string
+
+ // RoomMemberCount returns the number of members in the room of
+ // the current event.
+ RoomMemberCount() (int, error)
+
+ // HasPowerLevel returns whether the user has at least the given
+ // power in the room of the current event.
+ HasPowerLevel(userID, levelKey string) (bool, error)
+}
+
+// A kindAndRules is just here to simplify iteration of the (ordered)
+// kinds of rules.
+type kindAndRules struct {
+ Kind Kind
+ Rules []*Rule
+}
+
+// NewRuleSetEvaluator creates a new evaluator for the given rule set.
+func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator {
+ return &RuleSetEvaluator{
+ ec: ec,
+ ruleSet: []kindAndRules{
+ {OverrideKind, ruleSet.Override},
+ {ContentKind, ruleSet.Content},
+ {RoomKind, ruleSet.Room},
+ {SenderKind, ruleSet.Sender},
+ {UnderrideKind, ruleSet.Underride},
+ },
+ }
+}
+
+// MatchEvent returns the first matching rule. Returns nil if there
+// was no match rule.
+func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) {
+ // TODO: server-default rules have lower priority than user rules,
+ // but they are stored together with the user rules. It's a bit
+ // unclear what the specification (11.14.1.4 Predefined rules)
+ // means the ordering should be.
+ //
+ // The most reasonable interpretation is that default overrides
+ // still have lower priority than user content rules, so we
+ // iterate twice.
+ for _, rsat := range rse.ruleSet {
+ for _, defRules := range []bool{false, true} {
+ for _, rule := range rsat.Rules {
+ if rule.Default != defRules {
+ continue
+ }
+ ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec)
+ if err != nil {
+ return nil, err
+ }
+ if ok {
+ return rule, nil
+ }
+ }
+ }
+ }
+
+ // No matching rule.
+ return nil, nil
+}
+
+func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
+ if !rule.Enabled {
+ return false, nil
+ }
+
+ switch kind {
+ case OverrideKind, UnderrideKind:
+ for _, cond := range rule.Conditions {
+ ok, err := conditionMatches(cond, event, ec)
+ if err != nil {
+ return false, err
+ }
+ if !ok {
+ return false, nil
+ }
+ }
+ return true, nil
+
+ case ContentKind:
+ // TODO: "These configure behaviour for (unencrypted) messages
+ // that match certain patterns." - Does that mean "content.body"?
+ return patternMatches("content.body", rule.Pattern, event)
+
+ case RoomKind:
+ return rule.RuleID == event.RoomID(), nil
+
+ case SenderKind:
+ return rule.RuleID == event.Sender(), nil
+
+ default:
+ return false, nil
+ }
+}
+
+func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
+ switch cond.Kind {
+ case EventMatchCondition:
+ return patternMatches(cond.Key, cond.Pattern, event)
+
+ case ContainsDisplayNameCondition:
+ return patternMatches("content.body", ec.UserDisplayName(), event)
+
+ case RoomMemberCountCondition:
+ cmp, err := parseRoomMemberCountCondition(cond.Is)
+ if err != nil {
+ return false, fmt.Errorf("parsing room_member_count condition: %w", err)
+ }
+ n, err := ec.RoomMemberCount()
+ if err != nil {
+ return false, fmt.Errorf("RoomMemberCount failed: %w", err)
+ }
+ return cmp(n), nil
+
+ case SenderNotificationPermissionCondition:
+ return ec.HasPowerLevel(event.Sender(), cond.Key)
+
+ default:
+ return false, nil
+ }
+}
+
+func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) {
+ re, err := globToRegexp(pattern)
+ if err != nil {
+ return false, err
+ }
+
+ var eventMap map[string]interface{}
+ if err = json.Unmarshal(event.JSON(), &eventMap); err != nil {
+ return false, fmt.Errorf("parsing event: %w", err)
+ }
+ v, err := lookupMapPath(strings.Split(key, "."), eventMap)
+ if err != nil {
+ // An unknown path is a benign error that shouldn't stop rule
+ // processing. It's just a non-match.
+ return false, nil
+ }
+
+ return re.MatchString(fmt.Sprint(v)), nil
+}
diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go
new file mode 100644
index 00000000..50e70336
--- /dev/null
+++ b/internal/pushrules/evaluate_test.go
@@ -0,0 +1,189 @@
+package pushrules
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
+ ev := mustEventFromJSON(t, `{}`)
+ defaultEnabled := &Rule{
+ RuleID: ".default.enabled",
+ Default: true,
+ Enabled: true,
+ }
+ userEnabled := &Rule{
+ RuleID: ".user.enabled",
+ Default: false,
+ Enabled: true,
+ }
+ userEnabled2 := &Rule{
+ RuleID: ".user.enabled.2",
+ Default: false,
+ Enabled: true,
+ }
+ tsts := []struct {
+ Name string
+ RuleSet RuleSet
+ Want *Rule
+ }{
+ {"empty", RuleSet{}, nil},
+ {"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled},
+ {"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled},
+ {"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled},
+ {"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled},
+ {"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled},
+ {"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled},
+ {"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ rse := NewRuleSetEvaluator(nil, &tst.RuleSet)
+ got, err := rse.MatchEvent(ev)
+ if err != nil {
+ t.Fatalf("MatchEvent failed: %v", err)
+ }
+ if diff := cmp.Diff(tst.Want, got); diff != "" {
+ t.Errorf("MatchEvent rule: +got -want:\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestRuleMatches(t *testing.T) {
+ emptyRule := Rule{Enabled: true}
+ tsts := []struct {
+ Name string
+ Kind Kind
+ Rule Rule
+ EventJSON string
+ Want bool
+ }{
+ {"emptyOverride", OverrideKind, emptyRule, `{}`, true},
+ {"emptyContent", ContentKind, emptyRule, `{}`, false},
+ {"emptyRoom", RoomKind, emptyRule, `{}`, true},
+ {"emptySender", SenderKind, emptyRule, `{}`, true},
+ {"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
+
+ {"disabled", OverrideKind, Rule{}, `{}`, false},
+
+ {"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true},
+ {"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
+
+ {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
+ {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
+
+ {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true},
+ {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false},
+
+ {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
+ {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
+
+ {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true},
+ {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil)
+ if err != nil {
+ t.Fatalf("ruleMatches failed: %v", err)
+ }
+ if got != tst.Want {
+ t.Errorf("ruleMatches: got %v, want %v", got, tst.Want)
+ }
+ })
+ }
+}
+
+func TestConditionMatches(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Cond Condition
+ EventJSON string
+ Want bool
+ }{
+ {"empty", Condition{}, `{}`, false},
+ {"empty", Condition{Kind: "unknownstring"}, `{}`, false},
+
+ {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true},
+
+ {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false},
+ {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true},
+
+ {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false},
+ {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true},
+ {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false},
+ {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true},
+ {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false},
+ {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true},
+ {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false},
+ {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true},
+ {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false},
+ {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true},
+
+ {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true},
+ {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{})
+ if err != nil {
+ t.Fatalf("conditionMatches failed: %v", err)
+ }
+ if got != tst.Want {
+ t.Errorf("conditionMatches: got %v, want %v", got, tst.Want)
+ }
+ })
+ }
+}
+
+type fakeEvaluationContext struct{}
+
+func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" }
+func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil }
+func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) {
+ return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil
+}
+
+func TestPatternMatches(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Key string
+ Pattern string
+ EventJSON string
+ Want bool
+ }{
+ {"empty", "", "", `{}`, false},
+
+ // Note that an empty pattern contains no wildcard characters,
+ // which implicitly means "*".
+ {"patternEmpty", "content", "", `{"content":{}}`, true},
+
+ {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true},
+ {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true},
+ {"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true},
+ {"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true},
+ {"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON))
+ if err != nil {
+ t.Fatalf("patternMatches failed: %v", err)
+ }
+ if got != tst.Want {
+ t.Errorf("patternMatches: got %v, want %v", got, tst.Want)
+ }
+ })
+ }
+}
+
+func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event {
+ ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ev
+}
diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go
new file mode 100644
index 00000000..bbed1f95
--- /dev/null
+++ b/internal/pushrules/pushrules.go
@@ -0,0 +1,71 @@
+package pushrules
+
+// An AccountRuleSets carries the rule sets associated with an
+// account.
+type AccountRuleSets struct {
+ Global RuleSet `json:"global"` // Required
+}
+
+// A RuleSet contains all the various push rules for an
+// account. Listed in decreasing order of priority.
+type RuleSet struct {
+ Override []*Rule `json:"override,omitempty"`
+ Content []*Rule `json:"content,omitempty"`
+ Room []*Rule `json:"room,omitempty"`
+ Sender []*Rule `json:"sender,omitempty"`
+ Underride []*Rule `json:"underride,omitempty"`
+}
+
+// A Rule contains matchers, conditions and final actions. While
+// evaluating, at most one rule is considered matching.
+//
+// Kind and scope are part of the push rules request/responses, but
+// not of the core data model.
+type Rule struct {
+ // RuleID is either a free identifier, or the sender's MXID for
+ // SenderKind. Required.
+ RuleID string `json:"rule_id"`
+
+ // Default indicates whether this is a server-defined default, or
+ // a user-provided rule. Required.
+ //
+ // The server-default rules have the lowest priority.
+ Default bool `json:"default"`
+
+ // Enabled allows the user to disable rules while keeping them
+ // around. Required.
+ Enabled bool `json:"enabled"`
+
+ // Actions describe the desired outcome, should the rule
+ // match. Required.
+ Actions []*Action `json:"actions"`
+
+ // Conditions provide the rule's conditions for OverrideKind and
+ // UnderrideKind. Not allowed for other kinds.
+ Conditions []*Condition `json:"conditions"`
+
+ // Pattern is the body pattern to match for ContentKind. Required
+ // for that kind. The interpretation is the same as that of
+ // Condition.Pattern.
+ Pattern string `json:"pattern"`
+}
+
+// Scope only has one valid value. See also AccountRuleSets.
+type Scope string
+
+const (
+ UnknownScope Scope = ""
+ GlobalScope Scope = "global"
+)
+
+// Kind is the type of push rule. See also RuleSet.
+type Kind string
+
+const (
+ UnknownKind Kind = ""
+ OverrideKind Kind = "override"
+ ContentKind Kind = "content"
+ RoomKind Kind = "room"
+ SenderKind Kind = "sender"
+ UnderrideKind Kind = "underride"
+)
diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go
new file mode 100644
index 00000000..027d35ef
--- /dev/null
+++ b/internal/pushrules/util.go
@@ -0,0 +1,125 @@
+package pushrules
+
+import (
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+)
+
+// ActionsToTweaks converts a list of actions into a primary action
+// kind and a tweaks map. Returns a nil map if it would have been
+// empty.
+func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) {
+ var kind ActionKind
+ tweaks := map[string]interface{}{}
+
+ for _, a := range as {
+ if a.Kind == SetTweakAction {
+ tweaks[string(a.Tweak)] = a.Value
+ continue
+ }
+ if kind != UnknownAction {
+ return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind)
+ }
+ kind = a.Kind
+ }
+
+ if len(tweaks) == 0 {
+ tweaks = nil
+ }
+
+ return kind, tweaks, nil
+}
+
+// BoolTweakOr returns the named tweak as a boolean, and returns `def`
+// on failure.
+func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool {
+ v, ok := tweaks[string(key)]
+ if !ok {
+ return def
+ }
+ b, ok := v.(bool)
+ if !ok {
+ return def
+ }
+ return b
+}
+
+// globToRegexp converts a Matrix glob-style pattern to a Regular expression.
+func globToRegexp(pattern string) (*regexp.Regexp, error) {
+ // TODO: It's unclear which glob characters are supported. The only
+ // place this is discussed is for the unrelated "m.policy.rule.*"
+ // events. Assuming, the same: /[*?]/
+ if !strings.ContainsAny(pattern, "*?") {
+ pattern = "*" + pattern + "*"
+ }
+
+ // The defined syntax doesn't allow escaping the glob wildcard
+ // characters, which makes this a straight-forward
+ // replace-after-quote.
+ pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta)
+ pattern = strings.Replace(pattern, "*", ".*", -1)
+ pattern = strings.Replace(pattern, "?", ".", -1)
+ return regexp.Compile("^(" + pattern + ")$")
+}
+
+// globNonMetaRegexp are the characters that are not considered glob
+// meta-characters (i.e. may need escaping).
+var globNonMetaRegexp = regexp.MustCompile("[^*?]+")
+
+// lookupMapPath traverses a hierarchical map structure, like the one
+// produced by json.Unmarshal, to return the leaf value. Traversing
+// arrays/slices is not supported, only objects/maps.
+func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) {
+ if len(path) == 0 {
+ return nil, fmt.Errorf("empty path")
+ }
+
+ var v interface{} = m
+ for i, key := range path {
+ m, ok := v.(map[string]interface{})
+ if !ok {
+ return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v)
+ }
+
+ v, ok = m[key]
+ if !ok {
+ return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], "."))
+ }
+ }
+
+ return v, nil
+}
+
+// parseRoomMemberCountCondition parses a string like "2", "==2", "<2"
+// into a function that checks if the argument to it fulfils the
+// condition.
+func parseRoomMemberCountCondition(s string) (func(int) bool, error) {
+ var b int
+ var cmp = func(a int) bool { return a == b }
+ switch {
+ case strings.HasPrefix(s, "<="):
+ cmp = func(a int) bool { return a <= b }
+ s = s[2:]
+ case strings.HasPrefix(s, ">="):
+ cmp = func(a int) bool { return a >= b }
+ s = s[2:]
+ case strings.HasPrefix(s, "<"):
+ cmp = func(a int) bool { return a < b }
+ s = s[1:]
+ case strings.HasPrefix(s, ">"):
+ cmp = func(a int) bool { return a > b }
+ s = s[1:]
+ case strings.HasPrefix(s, "=="):
+ // Same cmp as the default.
+ s = s[2:]
+ }
+
+ v, err := strconv.ParseInt(s, 10, 64)
+ if err != nil {
+ return nil, err
+ }
+ b = int(v)
+ return cmp, nil
+}
diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go
new file mode 100644
index 00000000..a951c55a
--- /dev/null
+++ b/internal/pushrules/util_test.go
@@ -0,0 +1,169 @@
+package pushrules
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestActionsToTweaks(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Input []*Action
+ WantKind ActionKind
+ WantTweaks map[string]interface{}
+ }{
+ {"empty", nil, UnknownAction, nil},
+ {"zero", []*Action{{}}, UnknownAction, nil},
+ {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil},
+ {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}},
+ {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}},
+ {
+ "all",
+ []*Action{
+ {Kind: CoalesceAction},
+ {Kind: SetTweakAction, Tweak: HighlightTweak},
+ {Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"},
+ },
+ CoalesceAction,
+ map[string]interface{}{"highlight": nil, "sound": "default"},
+ },
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ gotKind, gotTweaks, err := ActionsToTweaks(tst.Input)
+ if err != nil {
+ t.Fatalf("ActionsToTweaks failed: %v", err)
+ }
+ if gotKind != tst.WantKind {
+ t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind)
+ }
+ if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" {
+ t.Errorf("tweaks: +got -want:\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestBoolTweakOr(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Input map[string]interface{}
+ Def bool
+ Want bool
+ }{
+ {"nil", nil, false, false},
+ {"nilValue", map[string]interface{}{"highlight": nil}, true, true},
+ {"false", map[string]interface{}{"highlight": false}, true, false},
+ {"true", map[string]interface{}{"highlight": true}, false, true},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def)
+ if got != tst.Want {
+ t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want)
+ }
+ })
+ }
+}
+
+func TestGlobToRegexp(t *testing.T) {
+ tsts := []struct {
+ Input string
+ Want string
+ }{
+ {"", "^(.*.*)$"},
+ {"a", "^(.*a.*)$"},
+ {"a.b", "^(.*a\\.b.*)$"},
+ {"a?b", "^(a.b)$"},
+ {"a*b*", "^(a.*b.*)$"},
+ {"a*b?", "^(a.*b.)$"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Want, func(t *testing.T) {
+ got, err := globToRegexp(tst.Input)
+ if err != nil {
+ t.Fatalf("globToRegexp failed: %v", err)
+ }
+ if got.String() != tst.Want {
+ t.Errorf("got %v, want %v", got.String(), tst.Want)
+ }
+ })
+ }
+}
+
+func TestLookupMapPath(t *testing.T) {
+ tsts := []struct {
+ Path []string
+ Root map[string]interface{}
+ Want interface{}
+ }{
+ {[]string{"a"}, map[string]interface{}{"a": "b"}, "b"},
+ {[]string{"a"}, map[string]interface{}{"a": 42}, 42},
+ {[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"},
+ }
+ for _, tst := range tsts {
+ t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) {
+ got, err := lookupMapPath(tst.Path, tst.Root)
+ if err != nil {
+ t.Fatalf("lookupMapPath failed: %v", err)
+ }
+ if diff := cmp.Diff(tst.Want, got); diff != "" {
+ t.Errorf("+got -want:\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLookupMapPathInvalid(t *testing.T) {
+ tsts := []struct {
+ Path []string
+ Root map[string]interface{}
+ }{
+ {nil, nil},
+ {[]string{"a"}, nil},
+ {[]string{"a", "b"}, map[string]interface{}{"a": "c"}},
+ }
+ for _, tst := range tsts {
+ t.Run(fmt.Sprint(tst.Path), func(t *testing.T) {
+ got, err := lookupMapPath(tst.Path, tst.Root)
+ if err == nil {
+ t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got)
+ }
+ })
+ }
+}
+
+func TestParseRoomMemberCountCondition(t *testing.T) {
+ tsts := []struct {
+ Input string
+ WantTrue []int
+ WantFalse []int
+ }{
+ {"1", []int{1}, []int{0, 2}},
+ {"==1", []int{1}, []int{0, 2}},
+ {"<1", []int{0}, []int{1, 2}},
+ {"<=1", []int{0, 1}, []int{2}},
+ {">1", []int{2}, []int{0, 1}},
+ {">=42", []int{42, 43}, []int{41}},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Input, func(t *testing.T) {
+ got, err := parseRoomMemberCountCondition(tst.Input)
+ if err != nil {
+ t.Fatalf("parseRoomMemberCountCondition failed: %v", err)
+ }
+ for _, v := range tst.WantTrue {
+ if !got(v) {
+ t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v)
+ }
+ }
+ for _, v := range tst.WantFalse {
+ if got(v) {
+ t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v)
+ }
+ }
+ })
+ }
+}
diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go
new file mode 100644
index 00000000..5d260f0b
--- /dev/null
+++ b/internal/pushrules/validate.go
@@ -0,0 +1,85 @@
+package pushrules
+
+import (
+ "fmt"
+ "regexp"
+)
+
+// ValidateRule checks the rule for errors. These follow from Sytests
+// and the specification.
+func ValidateRule(kind Kind, rule *Rule) []error {
+ var errs []error
+
+ if !validRuleIDRE.MatchString(rule.RuleID) {
+ errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
+ }
+
+ if len(rule.Actions) == 0 {
+ errs = append(errs, fmt.Errorf("missing actions"))
+ }
+ for _, action := range rule.Actions {
+ errs = append(errs, validateAction(action)...)
+ }
+
+ for _, cond := range rule.Conditions {
+ errs = append(errs, validateCondition(cond)...)
+ }
+
+ switch kind {
+ case OverrideKind, UnderrideKind:
+ // The empty list is allowed, but for JSON-encoding reasons,
+ // it must not be nil.
+ if rule.Conditions == nil {
+ errs = append(errs, fmt.Errorf("missing rule conditions"))
+ }
+
+ case ContentKind:
+ if rule.Pattern == "" {
+ errs = append(errs, fmt.Errorf("missing content rule pattern"))
+ }
+
+ case RoomKind, SenderKind:
+ // Do nothing.
+
+ default:
+ errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind))
+ }
+
+ return errs
+}
+
+// validRuleIDRE is a regexp for valid IDs.
+//
+// TODO: the specification doesn't seem to say what the rule ID syntax
+// is. A Sytest fails if it contains a backslash.
+var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`)
+
+// validateAction returns issues with an Action.
+func validateAction(action *Action) []error {
+ var errs []error
+
+ switch action.Kind {
+ case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction:
+ // Do nothing.
+
+ default:
+ errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind))
+ }
+
+ return errs
+}
+
+// validateCondition returns issues with a Condition.
+func validateCondition(cond *Condition) []error {
+ var errs []error
+
+ switch cond.Kind {
+ case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition:
+ // Do nothing.
+
+ default:
+ errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind))
+ }
+
+ return errs
+}
diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go
new file mode 100644
index 00000000..b276eb55
--- /dev/null
+++ b/internal/pushrules/validate_test.go
@@ -0,0 +1,163 @@
+package pushrules
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestValidateRuleNegatives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Kind Kind
+ Rule Rule
+ WantErrString string
+ }{
+ {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"},
+ {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"},
+ {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"},
+ {"noActions", OverrideKind, Rule{}, "missing actions"},
+ {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"},
+ {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"},
+ {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"},
+ {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"},
+ {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := ValidateRule(tst.Kind, &tst.Rule)
+ var foundErr error
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantErrString) {
+ foundErr = err
+ }
+ }
+ if foundErr == nil {
+ t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
+ }
+ })
+ }
+}
+
+func TestValidateRulePositives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Kind Kind
+ Rule Rule
+ WantNoErrString string
+ }{
+ {"invalidKind", OverrideKind, Rule{}, "invalid rule kind"},
+ {"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"},
+ {"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"},
+ {"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"},
+ {"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"},
+ {"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"},
+ {"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"},
+ {"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := ValidateRule(tst.Kind, &tst.Rule)
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantNoErrString) {
+ t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
+ }
+ }
+ })
+ }
+}
+
+func TestValidateActionNegatives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Action Action
+ WantErrString string
+ }{
+ {"emptyKind", Action{}, "invalid rule action kind"},
+ {"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := validateAction(&tst.Action)
+ var foundErr error
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantErrString) {
+ foundErr = err
+ }
+ }
+ if foundErr == nil {
+ t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
+ }
+ })
+ }
+}
+
+func TestValidateActionPositives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Action Action
+ WantNoErrString string
+ }{
+ {"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := validateAction(&tst.Action)
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantNoErrString) {
+ t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
+ }
+ }
+ })
+ }
+}
+
+func TestValidateConditionNegatives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Condition Condition
+ WantErrString string
+ }{
+ {"emptyKind", Condition{}, "invalid rule condition kind"},
+ {"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := validateCondition(&tst.Condition)
+ var foundErr error
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantErrString) {
+ foundErr = err
+ }
+ }
+ if foundErr == nil {
+ t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
+ }
+ })
+ }
+}
+
+func TestValidateConditionPositives(t *testing.T) {
+ tsts := []struct {
+ Name string
+ Condition Condition
+ WantNoErrString string
+ }{
+ {"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"},
+ }
+ for _, tst := range tsts {
+ t.Run(tst.Name, func(t *testing.T) {
+ errs := validateCondition(&tst.Condition)
+ for _, err := range errs {
+ t.Logf("Got error %#v", err)
+ if strings.Contains(err.Error(), tst.WantNoErrString) {
+ t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
+ }
+ }
+ })
+ }
+}
diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go
index 8d0d2dfa..19483b26 100644
--- a/internal/sqlutil/sql.go
+++ b/internal/sqlutil/sql.go
@@ -163,6 +163,7 @@ type StatementList []struct {
func (s StatementList) Prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.Statement, err = db.Prepare(statement.SQL); err != nil {
+ err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL)
return
}
}
diff --git a/q.sqlite b/q.sqlite
new file mode 100644
index 00000000..b7d6268e
--- /dev/null
+++ b/q.sqlite
Binary files differ
diff --git a/run-sytest.sh b/run-sytest.sh
new file mode 100755
index 00000000..47635fd1
--- /dev/null
+++ b/run-sytest.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+#
+# Runs SyTest either from Docker Hub, or from ../sytest. If it's run
+# locally, the Docker image is rebuilt first.
+#
+# Logs are stored in ../sytestout/logs.
+
+set -e
+set -o pipefail
+
+main() {
+ local tag=buster
+ local base_image=debian:$tag
+ local runargs=()
+
+ cd "$(dirname "$0")"
+
+ if [ -d ../sytest ]; then
+ local tmpdir
+ tmpdir="$(mktemp -d --tmpdir run-systest.XXXXXXXXXX)"
+ trap "rm -r '$tmpdir'" EXIT
+
+ if [ -z "$DISABLE_BUILDING_SYTEST" ]; then
+ echo "Re-building ../sytest Docker images..."
+
+ local status
+ (
+ cd ../sytest
+
+ docker build -f docker/base.Dockerfile --build-arg BASE_IMAGE="$base_image" --tag matrixdotorg/sytest:"$tag" .
+ docker build -f docker/dendrite.Dockerfile --build-arg SYTEST_IMAGE_TAG="$tag" --tag matrixdotorg/sytest-dendrite:latest .
+ ) &>"$tmpdir/buildlog" || status=$?
+ if (( status != 0 )); then
+ # Docker is very verbose, and we don't really care about
+ # building SyTest. So we accumulate and only output on
+ # failure.
+ cat "$tmpdir/buildlog" >&2
+ return $status
+ fi
+ fi
+
+ runargs+=( -v "$PWD/../sytest:/sytest:ro" )
+ fi
+ if [ -n "$SYTEST_POSTGRES" ]; then
+ runargs+=( -e POSTGRES=1 )
+ fi
+
+ local sytestout=$PWD/../sytestout
+ mkdir -p "$sytestout"/{logs,cache/go-build,cache/go-pkg}
+ docker run \
+ --rm \
+ --name "sytest-dendrite-${LOGNAME}" \
+ -e LOGS_USER=$(id -u) \
+ -e LOGS_GROUP=$(id -g) \
+ -v "$PWD:/src/:ro" \
+ -v "$sytestout/logs:/logs/" \
+ -v "$sytestout/cache/go-build:/root/.cache/go-build" \
+ -v "$sytestout/cache/go-pkg:/gopath/pkg" \
+ "${runargs[@]}" \
+ matrixdotorg/sytest-dendrite:latest "$@"
+}
+
+main "$@"
diff --git a/setup/base/base.go b/setup/base/base.go
index e3997754..ef3b2be2 100644
--- a/setup/base/base.go
+++ b/setup/base/base.go
@@ -30,6 +30,7 @@ import (
sentryhttp "github.com/getsentry/sentry-go/http"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
+ "github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/atomic"
@@ -271,6 +272,11 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
return f
}
+// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways.
+func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
+ return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation)
+}
+
// CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component.
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
diff --git a/setup/config/config_test.go b/setup/config/config_test.go
index 8f7611f0..6467b7c8 100644
--- a/setup/config/config_test.go
+++ b/setup/config/config_test.go
@@ -205,6 +205,11 @@ user_api:
max_open_conns: 100
max_idle_conns: 2
conn_max_lifetime: -1
+ pusher_database:
+ connection_string: file:pushserver.db
+ max_open_conns: 100
+ max_idle_conns: 2
+ conn_max_lifetime: -1
tracing:
enabled: false
jaeger:
diff --git a/setup/config/config_userapi.go b/setup/config/config_userapi.go
index 1cb5eba1..570dc603 100644
--- a/setup/config/config_userapi.go
+++ b/setup/config/config_userapi.go
@@ -13,6 +13,9 @@ type UserAPI struct {
// The length of time an OpenID token is condidered valid in milliseconds
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
+ // Disable TLS validation on HTTPS calls to push gatways. NOT RECOMMENDED!
+ PushGatewayDisableTLSValidation bool `yaml:"push_gateway_disable_tls_validation"`
+
// The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"`
diff --git a/setup/jetstream/streams.go b/setup/jetstream/streams.go
index 5810a2a9..3f07488f 100644
--- a/setup/jetstream/streams.go
+++ b/setup/jetstream/streams.go
@@ -18,7 +18,10 @@ var (
OutputKeyChangeEvent = "OutputKeyChangeEvent"
OutputTypingEvent = "OutputTypingEvent"
OutputClientData = "OutputClientData"
+ OutputNotificationData = "OutputNotificationData"
OutputReceiptEvent = "OutputReceiptEvent"
+ OutputStreamEvent = "OutputStreamEvent"
+ OutputReadUpdate = "OutputReadUpdate"
)
var streams = []*nats.StreamConfig{
@@ -58,4 +61,19 @@ var streams = []*nats.StreamConfig{
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
+ {
+ Name: OutputNotificationData,
+ Retention: nats.InterestPolicy,
+ Storage: nats.FileStorage,
+ },
+ {
+ Name: OutputStreamEvent,
+ Retention: nats.InterestPolicy,
+ Storage: nats.FileStorage,
+ },
+ {
+ Name: OutputReadUpdate,
+ Retention: nats.InterestPolicy,
+ Storage: nats.FileStorage,
+ },
}
diff --git a/setup/monolith.go b/setup/monolith.go
index 61125e4a..7dbd2eea 100644
--- a/setup/monolith.go
+++ b/setup/monolith.go
@@ -60,8 +60,8 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss
csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB,
m.FedClient, m.RoomserverAPI,
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
- m.FederationAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider,
- &m.Config.MSCs,
+ m.FederationAPI, m.UserAPI, m.KeyAPI,
+ m.ExtPublicRoomsProvider, &m.Config.MSCs,
)
federationapi.AddPublicRoutes(
ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient,
diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go
index c3650085..f01afce6 100644
--- a/syncapi/consumers/clientapi.go
+++ b/syncapi/consumers/clientapi.go
@@ -17,6 +17,7 @@ package consumers
import (
"context"
"encoding/json"
+ "fmt"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/internal/eventutil"
@@ -24,21 +25,26 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
+ "github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
+ "github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)
// OutputClientDataConsumer consumes events that originated in the client API server.
type OutputClientDataConsumer struct {
- ctx context.Context
- jetstream nats.JetStreamContext
- durable string
- topic string
- db storage.Database
- stream types.StreamProvider
- notifier *notifier.Notifier
+ ctx context.Context
+ jetstream nats.JetStreamContext
+ durable string
+ topic string
+ db storage.Database
+ stream types.StreamProvider
+ notifier *notifier.Notifier
+ serverName gomatrixserverlib.ServerName
+ producer *producers.UserAPIReadProducer
}
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
@@ -49,15 +55,18 @@ func NewOutputClientDataConsumer(
store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
+ producer *producers.UserAPIReadProducer,
) *OutputClientDataConsumer {
return &OutputClientDataConsumer{
- ctx: process.Context(),
- jetstream: js,
- topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
- durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"),
- db: store,
- notifier: notifier,
- stream: stream,
+ ctx: process.Context(),
+ jetstream: js,
+ topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
+ durable: cfg.Matrix.JetStream.Durable("SyncAPIClientAPIConsumer"),
+ db: store,
+ notifier: notifier,
+ stream: stream,
+ serverName: cfg.Matrix.ServerName,
+ producer: producer,
}
}
@@ -100,8 +109,48 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg)
}).Panicf("could not save account data")
}
+ if err = s.sendReadUpdate(ctx, userID, output); err != nil {
+ log.WithError(err).WithFields(logrus.Fields{
+ "user_id": userID,
+ "room_id": output.RoomID,
+ }).Errorf("Failed to generate read update")
+ sentry.CaptureException(err)
+ return false
+ }
+
s.stream.Advance(streamPos)
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
return true
}
+
+func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error {
+ if output.Type != "m.fully_read" || output.ReadMarker == nil {
+ return nil
+ }
+ _, serverName, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
+ }
+ if serverName != s.serverName {
+ return nil
+ }
+ var readPos types.StreamPosition
+ var fullyReadPos types.StreamPosition
+ if output.ReadMarker.Read != "" {
+ if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil {
+ return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
+ }
+ }
+ if output.ReadMarker.FullyRead != "" {
+ if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil {
+ return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
+ }
+ }
+ if readPos > 0 || fullyReadPos > 0 {
+ if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
+ return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
+ }
+ }
+ return nil
+}
diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go
index 392840ec..88158344 100644
--- a/syncapi/consumers/eduserver_receipts.go
+++ b/syncapi/consumers/eduserver_receipts.go
@@ -17,6 +17,7 @@ package consumers
import (
"context"
"encoding/json"
+ "fmt"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/eduserver/api"
@@ -24,21 +25,26 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
+ "github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
+ "github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)
// OutputReceiptEventConsumer consumes events that originated in the EDU server.
type OutputReceiptEventConsumer struct {
- ctx context.Context
- jetstream nats.JetStreamContext
- durable string
- topic string
- db storage.Database
- stream types.StreamProvider
- notifier *notifier.Notifier
+ ctx context.Context
+ jetstream nats.JetStreamContext
+ durable string
+ topic string
+ db storage.Database
+ stream types.StreamProvider
+ notifier *notifier.Notifier
+ serverName gomatrixserverlib.ServerName
+ producer *producers.UserAPIReadProducer
}
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
@@ -50,15 +56,18 @@ func NewOutputReceiptEventConsumer(
store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
+ producer *producers.UserAPIReadProducer,
) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{
- ctx: process.Context(),
- jetstream: js,
- topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent),
- durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"),
- db: store,
- notifier: notifier,
- stream: stream,
+ ctx: process.Context(),
+ jetstream: js,
+ topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent),
+ durable: cfg.Matrix.JetStream.Durable("SyncAPIEDUServerReceiptConsumer"),
+ db: store,
+ notifier: notifier,
+ stream: stream,
+ serverName: cfg.Matrix.ServerName,
+ producer: producer,
}
}
@@ -92,8 +101,42 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms
return true
}
+ if err = s.sendReadUpdate(ctx, output); err != nil {
+ log.WithError(err).WithFields(logrus.Fields{
+ "user_id": output.UserID,
+ "room_id": output.RoomID,
+ }).Errorf("Failed to generate read update")
+ sentry.CaptureException(err)
+ return false
+ }
+
s.stream.Advance(streamPos)
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return true
}
+
+func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output api.OutputReceiptEvent) error {
+ if output.Type != "m.read" {
+ return nil
+ }
+ _, serverName, err := gomatrixserverlib.SplitID('@', output.UserID)
+ if err != nil {
+ return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
+ }
+ if serverName != s.serverName {
+ return nil
+ }
+ var readPos types.StreamPosition
+ if output.EventID != "" {
+ if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil {
+ return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
+ }
+ }
+ if readPos > 0 {
+ if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
+ return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
+ }
+ }
+ return nil
+}
diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go
index 15485bb3..159657f9 100644
--- a/syncapi/consumers/roomserver.go
+++ b/syncapi/consumers/roomserver.go
@@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
+ "github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -45,6 +46,7 @@ type OutputRoomEventConsumer struct {
pduStream types.StreamProvider
inviteStream types.StreamProvider
notifier *notifier.Notifier
+ producer *producers.UserAPIStreamEventProducer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@@ -57,6 +59,7 @@ func NewOutputRoomEventConsumer(
pduStream types.StreamProvider,
inviteStream types.StreamProvider,
rsAPI api.RoomserverInternalAPI,
+ producer *producers.UserAPIStreamEventProducer,
) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{
ctx: process.Context(),
@@ -69,6 +72,7 @@ func NewOutputRoomEventConsumer(
pduStream: pduStream,
inviteStream: inviteStream,
rsAPI: rsAPI,
+ producer: producer,
}
}
@@ -194,6 +198,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
return nil
}
+ if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil {
+ log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID())
+ sentry.CaptureException(err)
+ return err
+ }
+
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
sentry.CaptureException(err)
diff --git a/syncapi/consumers/userapi.go b/syncapi/consumers/userapi.go
new file mode 100644
index 00000000..a3b2dd53
--- /dev/null
+++ b/syncapi/consumers/userapi.go
@@ -0,0 +1,110 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package consumers
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/getsentry/sentry-go"
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/syncapi/notifier"
+ "github.com/matrix-org/dendrite/syncapi/storage"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/nats-io/nats.go"
+ log "github.com/sirupsen/logrus"
+)
+
+// OutputNotificationDataConsumer consumes events that originated in
+// the Push server.
+type OutputNotificationDataConsumer struct {
+ ctx context.Context
+ jetstream nats.JetStreamContext
+ durable string
+ topic string
+ db storage.Database
+ notifier *notifier.Notifier
+ stream types.StreamProvider
+}
+
+// NewOutputNotificationDataConsumer creates a new consumer. Call
+// Start() to begin consuming.
+func NewOutputNotificationDataConsumer(
+ process *process.ProcessContext,
+ cfg *config.SyncAPI,
+ js nats.JetStreamContext,
+ store storage.Database,
+ notifier *notifier.Notifier,
+ stream types.StreamProvider,
+) *OutputNotificationDataConsumer {
+ s := &OutputNotificationDataConsumer{
+ ctx: process.Context(),
+ jetstream: js,
+ durable: cfg.Matrix.JetStream.Durable("SyncAPINotificationDataConsumer"),
+ topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
+ db: store,
+ notifier: notifier,
+ stream: stream,
+ }
+ return s
+}
+
+// Start starts consumption.
+func (s *OutputNotificationDataConsumer) Start() error {
+ return jetstream.JetStreamConsumer(
+ s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
+ nats.DeliverAll(), nats.ManualAck(),
+ )
+}
+
+// onMessage is called when the Sync server receives a new event from
+// the push server. It is not safe for this function to be called from
+// multiple goroutines, or else the sync stream position may race and
+// be incorrectly calculated.
+func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
+ userID := string(msg.Header.Get(jetstream.UserID))
+
+ // Parse out the event JSON
+ var data eventutil.NotificationData
+ if err := json.Unmarshal(msg.Data, &data); err != nil {
+ sentry.CaptureException(err)
+ log.WithField("user_id", userID).WithError(err).Error("user API consumer: message parse failure")
+ return true
+ }
+
+ streamPos, err := s.db.UpsertRoomUnreadNotificationCounts(ctx, userID, data.RoomID, data.UnreadNotificationCount, data.UnreadHighlightCount)
+ if err != nil {
+ sentry.CaptureException(err)
+ log.WithFields(log.Fields{
+ "user_id": userID,
+ "room_id": data.RoomID,
+ }).WithError(err).Error("Could not save notification counts")
+ return false
+ }
+
+ s.stream.Advance(streamPos)
+ s.notifier.OnNewNotificationData(userID, types.StreamingToken{NotificationDataPosition: streamPos})
+
+ log.WithFields(log.Fields{
+ "user_id": userID,
+ "room_id": data.RoomID,
+ "streamPos": streamPos,
+ }).Trace("Received notification data from user API")
+
+ return true
+}
diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go
index d853cc0e..6a641e6f 100644
--- a/syncapi/notifier/notifier.go
+++ b/syncapi/notifier/notifier.go
@@ -217,6 +217,17 @@ func (n *Notifier) OnNewInvite(
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
}
+func (n *Notifier) OnNewNotificationData(
+ userID string,
+ posUpdate types.StreamingToken,
+) {
+ n.streamLock.Lock()
+ defer n.streamLock.Unlock()
+
+ n.currPos.ApplyUpdates(posUpdate)
+ n.wakeupUsers([]string{userID}, nil, n.currPos)
+}
+
// GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed.
// notify for anything before sincePos
diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go
index c6d3df7e..60403d5d 100644
--- a/syncapi/notifier/notifier_test.go
+++ b/syncapi/notifier/notifier_test.go
@@ -219,7 +219,7 @@ func TestEDUWakeup(t *testing.T) {
go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
if err != nil {
- t.Errorf("TestNewInviteEventForUser error: %w", err)
+ t.Errorf("TestNewInviteEventForUser error: %v", err)
}
mustEqualPositions(t, pos, syncPositionNewEDU)
wg.Done()
diff --git a/syncapi/producers/userapi_readupdate.go b/syncapi/producers/userapi_readupdate.go
new file mode 100644
index 00000000..d56cab77
--- /dev/null
+++ b/syncapi/producers/userapi_readupdate.go
@@ -0,0 +1,62 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package producers
+
+import (
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/nats-io/nats.go"
+ log "github.com/sirupsen/logrus"
+)
+
+// UserAPIProducer produces events for the user API server to consume
+type UserAPIReadProducer struct {
+ Topic string
+ JetStream nats.JetStreamContext
+}
+
+// SendData sends account data to the user API server
+func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error {
+ m := &nats.Msg{
+ Subject: p.Topic,
+ Header: nats.Header{},
+ }
+ m.Header.Set(jetstream.UserID, userID)
+ m.Header.Set(jetstream.RoomID, roomID)
+
+ data := types.ReadUpdate{
+ UserID: userID,
+ RoomID: roomID,
+ Read: readPos,
+ FullyRead: fullyReadPos,
+ }
+ var err error
+ m.Data, err = json.Marshal(data)
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "user_id": userID,
+ "room_id": roomID,
+ "read_pos": readPos,
+ "fully_read_pos": fullyReadPos,
+ }).Tracef("Producing to topic '%s'", p.Topic)
+
+ _, err = p.JetStream.PublishMsg(m)
+ return err
+}
diff --git a/syncapi/producers/userapi_streamevent.go b/syncapi/producers/userapi_streamevent.go
new file mode 100644
index 00000000..2bbd19c0
--- /dev/null
+++ b/syncapi/producers/userapi_streamevent.go
@@ -0,0 +1,60 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package producers
+
+import (
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ log "github.com/sirupsen/logrus"
+)
+
+// UserAPIProducer produces events for the user API server to consume
+type UserAPIStreamEventProducer struct {
+ Topic string
+ JetStream nats.JetStreamContext
+}
+
+// SendData sends account data to the user API server
+func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
+ m := &nats.Msg{
+ Subject: p.Topic,
+ Header: nats.Header{},
+ }
+ m.Header.Set(jetstream.RoomID, roomID)
+
+ data := types.StreamedEvent{
+ Event: event,
+ StreamPosition: pos,
+ }
+ var err error
+ m.Data, err = json.Marshal(data)
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "room_id": roomID,
+ "event_id": event.EventID(),
+ "event_type": event.Type(),
+ "stream_pos": pos,
+ }).Tracef("Producing to topic '%s'", p.Topic)
+
+ _, err = p.JetStream.PublishMsg(m)
+ return err
+}
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index 126bc865..e4476633 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -18,6 +18,7 @@ import (
"context"
eduAPI "github.com/matrix-org/dendrite/eduserver/api"
+ "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
@@ -31,6 +32,7 @@ type Database interface {
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
+ MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@@ -138,6 +140,12 @@ type Database interface {
// GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
+ // UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
+ UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
+
+ // GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
+ GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
+
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)
diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go
new file mode 100644
index 00000000..f3fc4451
--- /dev/null
+++ b/syncapi/storage/postgres/notification_data_table.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/syncapi/storage/tables"
+ "github.com/matrix-org/dendrite/syncapi/types"
+)
+
+func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
+ _, err := db.Exec(notificationDataSchema)
+ if err != nil {
+ return nil, err
+ }
+ r := &notificationDataStatements{}
+ return r, sqlutil.StatementList{
+ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
+ {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
+ {&r.selectMaxID, selectMaxNotificationIDSQL},
+ }.Prepare(db)
+}
+
+type notificationDataStatements struct {
+ upsertRoomUnreadCounts *sql.Stmt
+ selectUserUnreadCounts *sql.Stmt
+ selectMaxID *sql.Stmt
+}
+
+const notificationDataSchema = `
+CREATE TABLE IF NOT EXISTS syncapi_notification_data (
+ id BIGSERIAL PRIMARY KEY,
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ notification_count BIGINT NOT NULL DEFAULT 0,
+ highlight_count BIGINT NOT NULL DEFAULT 0,
+ CONSTRAINT syncapi_notification_data_unique UNIQUE (user_id, room_id)
+);`
+
+const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data
+ (user_id, room_id, notification_count, highlight_count)
+ VALUES ($1, $2, $3, $4)
+ ON CONFLICT (user_id, room_id)
+ DO UPDATE SET notification_count = $3, highlight_count = $4
+ RETURNING id`
+
+const selectUserUnreadNotificationCountsSQL = `SELECT
+ id, room_id, notification_count, highlight_count
+ FROM syncapi_notification_data
+ WHERE
+ user_id = $1 AND
+ id BETWEEN $2 + 1 AND $3`
+
+const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
+
+func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
+ err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
+ return
+}
+
+func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
+ rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
+
+ roomCounts := map[string]*eventutil.NotificationData{}
+ for rows.Next() {
+ var id types.StreamPosition
+ var roomID string
+ var notificationCount, highlightCount int
+
+ if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
+ return nil, err
+ }
+
+ roomCounts[roomID] = &eventutil.NotificationData{
+ RoomID: roomID,
+ UnreadNotificationCount: notificationCount,
+ UnreadHighlightCount: highlightCount,
+ }
+ }
+ return roomCounts, rows.Err()
+}
+
+func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
+ var id int64
+ err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
+ return id, err
+}
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go
index 6f4e7749..60fe5b54 100644
--- a/syncapi/storage/postgres/syncserver.go
+++ b/syncapi/storage/postgres/syncserver.go
@@ -90,6 +90,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if err != nil {
return nil, err
}
+ notificationData, err := NewPostgresNotificationDataTable(d.db)
+ if err != nil {
+ return nil, err
+ }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
@@ -110,6 +114,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
SendToDevice: sendToDevice,
Receipts: receipts,
Memberships: memberships,
+ NotificationData: notificationData,
}
return &d, nil
}
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index 819851b3..87d7c6df 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -48,6 +48,7 @@ type Database struct {
Filter tables.Filter
Receipts tables.Receipts
Memberships tables.Memberships
+ NotificationData tables.NotificationData
}
func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) {
@@ -102,6 +103,14 @@ func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.S
return types.StreamPosition(id), nil
}
+func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
+ id, err := d.NotificationData.SelectMaxID(ctx)
+ if err != nil {
+ return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
+ }
+ return types.StreamPosition(id), nil
+}
+
func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs)
}
@@ -956,6 +965,18 @@ func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, stream
return receipts, err
}
+func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
+ err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, userID, roomID, notificationCount, highlightCount)
+ return err
+ })
+ return
+}
+
+func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
+ return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to)
+}
+
func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
}
diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go
new file mode 100644
index 00000000..4b3f074d
--- /dev/null
+++ b/syncapi/storage/sqlite3/notification_data_table.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/syncapi/storage/tables"
+ "github.com/matrix-org/dendrite/syncapi/types"
+)
+
+func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
+ _, err := db.Exec(notificationDataSchema)
+ if err != nil {
+ return nil, err
+ }
+ r := &notificationDataStatements{}
+ return r, sqlutil.StatementList{
+ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
+ {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
+ {&r.selectMaxID, selectMaxNotificationIDSQL},
+ }.Prepare(db)
+}
+
+type notificationDataStatements struct {
+ upsertRoomUnreadCounts *sql.Stmt
+ selectUserUnreadCounts *sql.Stmt
+ selectMaxID *sql.Stmt
+}
+
+const notificationDataSchema = `
+CREATE TABLE IF NOT EXISTS syncapi_notification_data (
+ id INTEGER PRIMARY KEY,
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ notification_count BIGINT NOT NULL DEFAULT 0,
+ highlight_count BIGINT NOT NULL DEFAULT 0,
+ CONSTRAINT syncapi_notifications_unique UNIQUE (user_id, room_id)
+);`
+
+const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data
+ (user_id, room_id, notification_count, highlight_count)
+ VALUES ($1, $2, $3, $4)
+ ON CONFLICT (user_id, room_id)
+ DO UPDATE SET notification_count = $3, highlight_count = $4
+ RETURNING id`
+
+const selectUserUnreadNotificationCountsSQL = `SELECT
+ id, room_id, notification_count, highlight_count
+ FROM syncapi_notification_data
+ WHERE
+ user_id = $1 AND
+ id BETWEEN $2 + 1 AND $3`
+
+const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
+
+func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
+ err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
+ return
+}
+
+func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
+ rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
+
+ roomCounts := map[string]*eventutil.NotificationData{}
+ for rows.Next() {
+ var id types.StreamPosition
+ var roomID string
+ var notificationCount, highlightCount int
+
+ if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
+ return nil, err
+ }
+
+ roomCounts[roomID] = &eventutil.NotificationData{
+ RoomID: roomID,
+ UnreadNotificationCount: notificationCount,
+ UnreadHighlightCount: highlightCount,
+ }
+ }
+ return roomCounts, rows.Err()
+}
+
+func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
+ var id int64
+ err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
+ return id, err
+}
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
index 581ee692..1b256f91 100644
--- a/syncapi/storage/sqlite3/output_room_events_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -62,16 +62,19 @@ const selectEventsSQL = "" +
const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3"
+
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
+
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3"
+
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectMaxEventIDSQL = "" +
@@ -85,6 +88,7 @@ const selectStateInRangeSQL = "" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2)" +
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
+
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const deleteEventsForRoomSQL = "" +
@@ -95,10 +99,12 @@ const selectContextEventSQL = "" +
const selectContextBeforeEventSQL = "" +
"SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2"
+
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectContextAfterEventSQL = "" +
"SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2"
+
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
type outputRoomEventsStatements struct {
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
index 706d43f8..f5ae9fdd 100644
--- a/syncapi/storage/sqlite3/syncserver.go
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -100,6 +100,10 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
if err != nil {
return err
}
+ notificationData, err := NewSqliteNotificationDataTable(d.db)
+ if err != nil {
+ return err
+ }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
@@ -120,6 +124,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
SendToDevice: sendToDevice,
Receipts: receipts,
Memberships: memberships,
+ NotificationData: notificationData,
}
return nil
}
diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go
index 1d807ee6..1ebb4265 100644
--- a/syncapi/storage/tables/interface.go
+++ b/syncapi/storage/tables/interface.go
@@ -19,6 +19,7 @@ import (
"database/sql"
eduAPI "github.com/matrix-org/dendrite/eduserver/api"
+ "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -171,3 +172,9 @@ type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
}
+
+type NotificationData interface {
+ UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
+ SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error)
+ SelectMaxID(ctx context.Context) (int64, error)
+}
diff --git a/syncapi/streams/stream_notificationdata.go b/syncapi/streams/stream_notificationdata.go
new file mode 100644
index 00000000..8ba9e07c
--- /dev/null
+++ b/syncapi/streams/stream_notificationdata.go
@@ -0,0 +1,55 @@
+package streams
+
+import (
+ "context"
+
+ "github.com/matrix-org/dendrite/syncapi/types"
+)
+
+type NotificationDataStreamProvider struct {
+ StreamProvider
+}
+
+func (p *NotificationDataStreamProvider) Setup() {
+ p.StreamProvider.Setup()
+
+ id, err := p.DB.MaxStreamPositionForNotificationData(context.Background())
+ if err != nil {
+ panic(err)
+ }
+ p.latest = id
+}
+
+func (p *NotificationDataStreamProvider) CompleteSync(
+ ctx context.Context,
+ req *types.SyncRequest,
+) types.StreamPosition {
+ return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
+}
+
+func (p *NotificationDataStreamProvider) IncrementalSync(
+ ctx context.Context,
+ req *types.SyncRequest,
+ from, to types.StreamPosition,
+) types.StreamPosition {
+ // We want counts for all possible rooms, so always start from zero.
+ countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to)
+ if err != nil {
+ req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed")
+ return from
+ }
+
+ // We're merely decorating existing rooms. Note that the Join map
+ // values are not pointers.
+ for roomID, jr := range req.Response.Rooms.Join {
+ counts := countsByRoom[roomID]
+ if counts == nil {
+ continue
+ }
+
+ jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount
+ jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount
+ req.Response.Rooms.Join[roomID] = jr
+ }
+ return to
+}
diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go
index c71095af..17951acb 100644
--- a/syncapi/streams/streams.go
+++ b/syncapi/streams/streams.go
@@ -12,13 +12,14 @@ import (
)
type Streams struct {
- PDUStreamProvider types.StreamProvider
- TypingStreamProvider types.StreamProvider
- ReceiptStreamProvider types.StreamProvider
- InviteStreamProvider types.StreamProvider
- SendToDeviceStreamProvider types.StreamProvider
- AccountDataStreamProvider types.StreamProvider
- DeviceListStreamProvider types.StreamProvider
+ PDUStreamProvider types.StreamProvider
+ TypingStreamProvider types.StreamProvider
+ ReceiptStreamProvider types.StreamProvider
+ InviteStreamProvider types.StreamProvider
+ SendToDeviceStreamProvider types.StreamProvider
+ AccountDataStreamProvider types.StreamProvider
+ DeviceListStreamProvider types.StreamProvider
+ NotificationDataStreamProvider types.StreamProvider
}
func NewSyncStreamProviders(
@@ -47,6 +48,9 @@ func NewSyncStreamProviders(
StreamProvider: StreamProvider{DB: d},
userAPI: userAPI,
},
+ NotificationDataStreamProvider: &NotificationDataStreamProvider{
+ StreamProvider: StreamProvider{DB: d},
+ },
DeviceListStreamProvider: &DeviceListStreamProvider{
StreamProvider: StreamProvider{DB: d},
rsAPI: rsAPI,
@@ -60,6 +64,7 @@ func NewSyncStreamProviders(
streams.InviteStreamProvider.Setup()
streams.SendToDeviceStreamProvider.Setup()
streams.AccountDataStreamProvider.Setup()
+ streams.NotificationDataStreamProvider.Setup()
streams.DeviceListStreamProvider.Setup()
return streams
@@ -67,12 +72,13 @@ func NewSyncStreamProviders(
func (s *Streams) Latest(ctx context.Context) types.StreamingToken {
return types.StreamingToken{
- PDUPosition: s.PDUStreamProvider.LatestPosition(ctx),
- TypingPosition: s.TypingStreamProvider.LatestPosition(ctx),
- ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx),
- InvitePosition: s.InviteStreamProvider.LatestPosition(ctx),
- SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx),
- AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx),
- DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx),
+ PDUPosition: s.PDUStreamProvider.LatestPosition(ctx),
+ TypingPosition: s.TypingStreamProvider.LatestPosition(ctx),
+ ReceiptPosition: s.ReceiptStreamProvider.LatestPosition(ctx),
+ InvitePosition: s.InviteStreamProvider.LatestPosition(ctx),
+ SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx),
+ AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx),
+ NotificationDataPosition: s.NotificationDataStreamProvider.LatestPosition(ctx),
+ DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx),
}
}
diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go
index ca35951a..2c9920d1 100644
--- a/syncapi/sync/requestpool.go
+++ b/syncapi/sync/requestpool.go
@@ -189,7 +189,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
currentPos.ApplyUpdates(userStreamListener.GetSyncPosition())
}
} else {
- syncReq.Log.Debugln("Responding to sync immediately")
+ syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
}
if syncReq.Since.IsEmpty() {
@@ -213,6 +213,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync(
syncReq.Context, syncReq,
),
+ NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync(
+ syncReq.Context, syncReq,
+ ),
DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync(
syncReq.Context, syncReq,
),
@@ -244,6 +247,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
syncReq.Context, syncReq,
syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition,
),
+ NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync(
+ syncReq.Context, syncReq,
+ syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition,
+ ),
DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition,
diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go
index 72462459..cb9890ff 100644
--- a/syncapi/syncapi.go
+++ b/syncapi/syncapi.go
@@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/consumers"
"github.com/matrix-org/dendrite/syncapi/notifier"
+ "github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/routing"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
@@ -64,6 +65,18 @@ func AddPublicRoutes(
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier)
+ userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
+ JetStream: js,
+ Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent),
+ }
+
+ userAPIReadUpdateProducer := &producers.UserAPIReadProducer{
+ JetStream: js,
+ Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate),
+ }
+
+ _ = userAPIReadUpdateProducer
+
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
js, keyAPI, rsAPI, syncDB, notifier,
@@ -75,7 +88,7 @@ func AddPublicRoutes(
roomConsumer := consumers.NewOutputRoomEventConsumer(
process, cfg, js, syncDB, notifier, streams.PDUStreamProvider,
- streams.InviteStreamProvider, rsAPI,
+ streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer,
)
if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer")
@@ -83,11 +96,19 @@ func AddPublicRoutes(
clientConsumer := consumers.NewOutputClientDataConsumer(
process, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider,
+ userAPIReadUpdateProducer,
)
if err = clientConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start client data consumer")
}
+ notificationConsumer := consumers.NewOutputNotificationDataConsumer(
+ process, cfg, js, syncDB, notifier, streams.NotificationDataStreamProvider,
+ )
+ if err = notificationConsumer.Start(); err != nil {
+ logrus.WithError(err).Panicf("failed to start notification data consumer")
+ }
+
typingConsumer := consumers.NewOutputTypingEventConsumer(
process, cfg, js, syncDB, eduCache, notifier, streams.TypingStreamProvider,
)
@@ -104,6 +125,7 @@ func AddPublicRoutes(
receiptConsumer := consumers.NewOutputReceiptEventConsumer(
process, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider,
+ userAPIReadUpdateProducer,
)
if err = receiptConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start receipts consumer")
diff --git a/syncapi/types/types.go b/syncapi/types/types.go
index c2e8ed01..4150e6c9 100644
--- a/syncapi/types/types.go
+++ b/syncapi/types/types.go
@@ -95,13 +95,14 @@ const (
)
type StreamingToken struct {
- PDUPosition StreamPosition
- TypingPosition StreamPosition
- ReceiptPosition StreamPosition
- SendToDevicePosition StreamPosition
- InvitePosition StreamPosition
- AccountDataPosition StreamPosition
- DeviceListPosition StreamPosition
+ PDUPosition StreamPosition
+ TypingPosition StreamPosition
+ ReceiptPosition StreamPosition
+ SendToDevicePosition StreamPosition
+ InvitePosition StreamPosition
+ AccountDataPosition StreamPosition
+ DeviceListPosition StreamPosition
+ NotificationDataPosition StreamPosition
}
// This will be used as a fallback by json.Marshal.
@@ -117,10 +118,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) {
func (t StreamingToken) String() string {
posStr := fmt.Sprintf(
- "s%d_%d_%d_%d_%d_%d_%d",
+ "s%d_%d_%d_%d_%d_%d_%d_%d",
t.PDUPosition, t.TypingPosition,
t.ReceiptPosition, t.SendToDevicePosition,
- t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition,
+ t.InvitePosition, t.AccountDataPosition,
+ t.DeviceListPosition, t.NotificationDataPosition,
)
return posStr
}
@@ -142,12 +144,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
return true
case t.DeviceListPosition > other.DeviceListPosition:
return true
+ case t.NotificationDataPosition > other.NotificationDataPosition:
+ return true
}
return false
}
func (t *StreamingToken) IsEmpty() bool {
- return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0
+ return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition+t.NotificationDataPosition == 0
}
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
@@ -185,6 +189,9 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) {
if other.DeviceListPosition > t.DeviceListPosition {
t.DeviceListPosition = other.DeviceListPosition
}
+ if other.NotificationDataPosition > t.NotificationDataPosition {
+ t.NotificationDataPosition = other.NotificationDataPosition
+ }
}
type TopologyToken struct {
@@ -277,7 +284,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
// s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions
tok = strings.Split(tok, ".")[0]
parts := strings.Split(tok[1:], "_")
- var positions [7]StreamPosition
+ var positions [8]StreamPosition
for i, p := range parts {
if i >= len(positions) {
break
@@ -291,13 +298,14 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
positions[i] = StreamPosition(pos)
}
token = StreamingToken{
- PDUPosition: positions[0],
- TypingPosition: positions[1],
- ReceiptPosition: positions[2],
- SendToDevicePosition: positions[3],
- InvitePosition: positions[4],
- AccountDataPosition: positions[5],
- DeviceListPosition: positions[6],
+ PDUPosition: positions[0],
+ TypingPosition: positions[1],
+ ReceiptPosition: positions[2],
+ SendToDevicePosition: positions[3],
+ InvitePosition: positions[4],
+ AccountDataPosition: positions[5],
+ DeviceListPosition: positions[6],
+ NotificationDataPosition: positions[7],
}
return token, nil
}
@@ -383,6 +391,10 @@ type JoinResponse struct {
AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data"`
+ UnreadNotifications struct {
+ HighlightCount int `json:"highlight_count"`
+ NotificationCount int `json:"notification_count"`
+ } `json:"unread_notifications"`
}
// NewJoinResponse creates an empty response with initialised arrays.
@@ -462,3 +474,16 @@ type Peek struct {
New bool
Deleted bool
}
+
+type ReadUpdate struct {
+ UserID string `json:"user_id"`
+ RoomID string `json:"room_id"`
+ Read StreamPosition `json:"read,omitempty"`
+ FullyRead StreamPosition `json:"fully_read,omitempty"`
+}
+
+// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
+type StreamedEvent struct {
+ Event *gomatrixserverlib.HeaderedEvent `json:"event"`
+ StreamPosition StreamPosition `json:"stream_position"`
+}
diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go
index cda178b3..ff78bfb9 100644
--- a/syncapi/types/types_test.go
+++ b/syncapi/types/types_test.go
@@ -9,10 +9,10 @@ import (
func TestSyncTokens(t *testing.T) {
shouldPass := map[string]string{
- "s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(),
- "s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(),
- "s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(),
- "t3_1": TopologyToken{3, 1}.String(),
+ "s4_0_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0}.String(),
+ "s3_1_0_0_0_0_2_0": StreamingToken{3, 1, 0, 0, 0, 0, 2, 0}.String(),
+ "s3_1_2_3_5_0_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0, 0}.String(),
+ "t3_1": TopologyToken{3, 1}.String(),
}
for a, b := range shouldPass {
diff --git a/sytest-blacklist b/sytest-blacklist
index e8617dcd..7f518b21 100644
--- a/sytest-blacklist
+++ b/sytest-blacklist
@@ -30,3 +30,9 @@ Local device key changes appear in /keys/changes
Remove group category
Remove group role
+# Flakey
+AS-ghosted users can use rooms themselves
+
+# Flakey, need additional investigation
+Messages that notify from another user increment notification_count
+Messages that highlight from another user increment unread highlight count
diff --git a/sytest-whitelist b/sytest-whitelist
index 3e38176f..602f8646 100644
--- a/sytest-whitelist
+++ b/sytest-whitelist
@@ -339,17 +339,17 @@ Existing members see new members' join events
Inbound federation can receive events
Inbound federation can receive redacted events
Can logout current device
-Can send a message directly to a device using PUT /sendToDevice
-Can recv a device message using /sync
-Can recv device messages until they are acknowledged
-Device messages with the same txn_id are deduplicated
-Device messages wake up /sync
-Can recv device messages over federation
-Device messages over federation wake up /sync
-Can send messages with a wildcard device id
-Can send messages with a wildcard device id to two devices
-Wildcard device messages wake up /sync
-Wildcard device messages over federation wake up /sync
+Can send a message directly to a device using PUT /sendToDevice
+Can recv a device message using /sync
+Can recv device messages until they are acknowledged
+Device messages with the same txn_id are deduplicated
+Device messages wake up /sync
+Can recv device messages over federation
+Device messages over federation wake up /sync
+Can send messages with a wildcard device id
+Can send messages with a wildcard device id to two devices
+Wildcard device messages wake up /sync
+Wildcard device messages over federation wake up /sync
Can send a to-device message to two users which both receive it using /sync
User can create and send/receive messages in a room with version 6
local user can join room with version 6
@@ -477,7 +477,7 @@ Federation key API can act as a notary server via a GET request
Inbound /make_join rejects attempts to join rooms where all users have left
Inbound federation rejects invites which include invalid JSON for room version 6
Inbound federation rejects invite rejections which include invalid JSON for room version 6
-GET /capabilities is present and well formed for registered user
+GET /capabilities is present and well formed for registered user
m.room.history_visibility == "joined" allows/forbids appropriately for Guest users
m.room.history_visibility == "joined" allows/forbids appropriately for Real users
POST rejects invalid utf-8 in JSON
@@ -588,6 +588,59 @@ User can invite remote user to room with version 9
Remote user can backfill in a room with version 9
Can reject invites over federation for rooms with version 9
Can receive redactions from regular users over federation in room version 9
+Pushers created with a different access token are deleted on password change
+Pushers created with a the same access token are not deleted on password change
+Can fetch a user's pushers
+Can add global push rule for room
+Can add global push rule for sender
+Can add global push rule for content
+Can add global push rule for override
+Can add global push rule for underride
+Can add global push rule for content
+New rules appear before old rules by default
+Can add global push rule before an existing rule
+Can add global push rule after an existing rule
+Can delete a push rule
+Can disable a push rule
+Adding the same push rule twice is idempotent
+Can change the actions of default rules
+Can change the actions of a user specified rule
+Adding a push rule wakes up an incremental /sync
+Disabling a push rule wakes up an incremental /sync
+Enabling a push rule wakes up an incremental /sync
+Setting actions for a push rule wakes up an incremental /sync
+Can enable/disable default rules
+Trying to add push rule with missing template fails with 400
+Trying to add push rule with missing rule_id fails with 400
+Trying to add push rule with empty rule_id fails with 400
+Trying to add push rule with invalid template fails with 400
+Trying to add push rule with rule_id with slashes fails with 400
+Trying to add push rule with override rule without conditions fails with 400
+Trying to add push rule with underride rule without conditions fails with 400
+Trying to add push rule with condition without kind fails with 400
+Trying to add push rule with content rule without pattern fails with 400
+Trying to add push rule with no actions fails with 400
+Trying to add push rule with invalid action fails with 400
+Trying to add push rule with invalid attr fails with 400
+Trying to add push rule with invalid value for enabled fails with 400
+Trying to get push rules with no trailing slash fails with 400
+Trying to get push rules with scope without trailing slash fails with 400
+Trying to get push rules with template without tailing slash fails with 400
+Trying to get push rules with unknown scope fails with 400
+Trying to get push rules with unknown template fails with 400
+Trying to get push rules with unknown attribute fails with 400
+Getting push rules doesn't corrupt the cache SYN-390
+Test that a message is pushed
+Invites are pushed
+Rooms with names are correctly named in pushes
+Rooms with canonical alias are correctly named in pushed
+Rooms with many users are correctly pushed
+Don't get pushed for rooms you've muted
+Rejected events are not pushed
+Test that rejected pushers are removed.
+Notifications can be viewed with GET /notifications
+Trying to add push rule with no scope fails with 400
+Trying to add push rule with invalid scope fails with 400
Forward extremities remain so even after the next events are populated as outliers
If a device list update goes missing, the server resyncs on the next one
uploading self-signing key notifies over federation
@@ -607,4 +660,4 @@ registration accepts non-ascii passwords
registration with inhibit_login inhibits login
The operation must be consistent through an interactive authentication session
Multiple calls to /sync should not cause 500 errors
-
+/context/ with lazy_load_members filter works
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 2be662e5..e9cdbe01 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/internal/pushrules"
)
// UserInternalAPI is the internal API for information about users and devices.
@@ -28,6 +29,7 @@ type UserInternalAPI interface {
LoginTokenInternalAPI
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
+
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
@@ -37,6 +39,10 @@ type UserInternalAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
+ PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
+ PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
+ PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error
+
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
@@ -45,6 +51,9 @@ type UserInternalAPI interface {
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
+ QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
+ QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error
+ QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
}
type PerformKeyBackupRequest struct {
@@ -424,3 +433,77 @@ const (
// AccountTypeAppService indicates this is an appservice account
AccountTypeAppService AccountType = 4
)
+
+type QueryPushersRequest struct {
+ Localpart string
+}
+
+type QueryPushersResponse struct {
+ Pushers []Pusher `json:"pushers"`
+}
+
+type PerformPusherSetRequest struct {
+ Pusher // Anonymous field because that's how clientapi unmarshals it.
+ Localpart string
+ Append bool `json:"append"`
+}
+
+type PerformPusherDeletionRequest struct {
+ Localpart string
+ SessionID int64
+}
+
+// Pusher represents a push notification subscriber
+type Pusher struct {
+ SessionID int64 `json:"session_id,omitempty"`
+ PushKey string `json:"pushkey"`
+ PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
+ Kind PusherKind `json:"kind"`
+ AppID string `json:"app_id"`
+ AppDisplayName string `json:"app_display_name"`
+ DeviceDisplayName string `json:"device_display_name"`
+ ProfileTag string `json:"profile_tag"`
+ Language string `json:"lang"`
+ Data map[string]interface{} `json:"data"`
+}
+
+type PusherKind string
+
+const (
+ EmailKind PusherKind = "email"
+ HTTPKind PusherKind = "http"
+)
+
+type PerformPushRulesPutRequest struct {
+ UserID string `json:"user_id"`
+ RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
+}
+
+type QueryPushRulesRequest struct {
+ UserID string `json:"user_id"`
+}
+
+type QueryPushRulesResponse struct {
+ RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
+}
+
+type QueryNotificationsRequest struct {
+ Localpart string `json:"localpart"` // Required.
+ From string `json:"from,omitempty"`
+ Limit int `json:"limit,omitempty"`
+ Only string `json:"only,omitempty"`
+}
+
+type QueryNotificationsResponse struct {
+ NextToken string `json:"next_token"`
+ Notifications []*Notification `json:"notifications"` // Required.
+}
+
+type Notification struct {
+ Actions []*pushrules.Action `json:"actions"` // Required.
+ Event gomatrixserverlib.ClientEvent `json:"event"` // Required.
+ ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional.
+ Read bool `json:"read"` // Required.
+ RoomID string `json:"room_id"` // Required.
+ TS gomatrixserverlib.Timestamp `json:"ts"` // Required.
+}
diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go
index aa069f40..9334f445 100644
--- a/userapi/api/api_trace.go
+++ b/userapi/api/api_trace.go
@@ -79,6 +79,21 @@ func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *Perfor
util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res))
return err
}
+func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error {
+ err := t.Impl.PerformPusherSet(ctx, req, res)
+ util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res))
+ return err
+}
+func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error {
+ err := t.Impl.PerformPusherDeletion(ctx, req, res)
+ util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res))
+ return err
+}
+func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error {
+ err := t.Impl.PerformPushRulesPut(ctx, req, res)
+ util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res))
+ return err
+}
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) {
t.Impl.QueryKeyBackup(ctx, req, res)
util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res))
@@ -118,6 +133,21 @@ func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryO
util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res))
return err
}
+func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error {
+ err := t.Impl.QueryPushers(ctx, req, res)
+ util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res))
+ return err
+}
+func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error {
+ err := t.Impl.QueryPushRules(ctx, req, res)
+ util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res))
+ return err
+}
+func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error {
+ err := t.Impl.QueryNotifications(ctx, req, res)
+ util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res))
+ return err
+}
func js(thing interface{}) string {
b, err := json.Marshal(thing)
diff --git a/userapi/consumers/syncapi_readupdate.go b/userapi/consumers/syncapi_readupdate.go
new file mode 100644
index 00000000..2e58020b
--- /dev/null
+++ b/userapi/consumers/syncapi_readupdate.go
@@ -0,0 +1,136 @@
+package consumers
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal/pushgateway"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ uapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/producers"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/dendrite/userapi/util"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ log "github.com/sirupsen/logrus"
+)
+
+type OutputReadUpdateConsumer struct {
+ ctx context.Context
+ cfg *config.UserAPI
+ jetstream nats.JetStreamContext
+ durable string
+ db storage.Database
+ pgClient pushgateway.Client
+ ServerName gomatrixserverlib.ServerName
+ topic string
+ userAPI uapi.UserInternalAPI
+ syncProducer *producers.SyncAPI
+}
+
+func NewOutputReadUpdateConsumer(
+ process *process.ProcessContext,
+ cfg *config.UserAPI,
+ js nats.JetStreamContext,
+ store storage.Database,
+ pgClient pushgateway.Client,
+ userAPI uapi.UserInternalAPI,
+ syncProducer *producers.SyncAPI,
+) *OutputReadUpdateConsumer {
+ return &OutputReadUpdateConsumer{
+ ctx: process.Context(),
+ cfg: cfg,
+ jetstream: js,
+ db: store,
+ ServerName: cfg.Matrix.ServerName,
+ durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"),
+ topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate),
+ pgClient: pgClient,
+ userAPI: userAPI,
+ syncProducer: syncProducer,
+ }
+}
+
+func (s *OutputReadUpdateConsumer) Start() error {
+ if err := jetstream.JetStreamConsumer(
+ s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
+ nats.DeliverAll(), nats.ManualAck(),
+ ); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
+ var read types.ReadUpdate
+ if err := json.Unmarshal(msg.Data, &read); err != nil {
+ log.WithError(err).Error("userapi clientapi consumer: message parse failure")
+ return true
+ }
+ if read.FullyRead == 0 && read.Read == 0 {
+ return true
+ }
+
+ userID := string(msg.Header.Get(jetstream.UserID))
+ roomID := string(msg.Header.Get(jetstream.RoomID))
+
+ localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
+ return true
+ }
+ if domain != s.ServerName {
+ log.Error("userapi clientapi consumer: not a local user")
+ return true
+ }
+
+ log := log.WithFields(log.Fields{
+ "room_id": roomID,
+ "user_id": userID,
+ })
+ log.Tracef("Received read update from sync API: %#v", read)
+
+ if read.Read > 0 {
+ updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true)
+ if err != nil {
+ log.WithError(err).Error("userapi EDU consumer")
+ return false
+ }
+
+ if updated {
+ if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
+ log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
+ return false
+ }
+ if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
+ log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
+ return false
+ }
+ }
+ }
+
+ if read.FullyRead > 0 {
+ deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead))
+ if err != nil {
+ log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
+ return false
+ }
+
+ if deleted {
+ if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
+ log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed")
+ return false
+ }
+
+ if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil {
+ log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
+ return false
+ }
+ }
+ }
+
+ return true
+}
diff --git a/userapi/consumers/syncapi_streamevent.go b/userapi/consumers/syncapi_streamevent.go
new file mode 100644
index 00000000..11081327
--- /dev/null
+++ b/userapi/consumers/syncapi_streamevent.go
@@ -0,0 +1,588 @@
+package consumers
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/internal/pushgateway"
+ "github.com/matrix-org/dendrite/internal/pushrules"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/setup/process"
+ "github.com/matrix-org/dendrite/syncapi/types"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/producers"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/util"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ log "github.com/sirupsen/logrus"
+)
+
+type OutputStreamEventConsumer struct {
+ ctx context.Context
+ cfg *config.UserAPI
+ userAPI api.UserInternalAPI
+ rsAPI rsapi.RoomserverInternalAPI
+ jetstream nats.JetStreamContext
+ durable string
+ db storage.Database
+ topic string
+ pgClient pushgateway.Client
+ syncProducer *producers.SyncAPI
+}
+
+func NewOutputStreamEventConsumer(
+ process *process.ProcessContext,
+ cfg *config.UserAPI,
+ js nats.JetStreamContext,
+ store storage.Database,
+ pgClient pushgateway.Client,
+ userAPI api.UserInternalAPI,
+ rsAPI rsapi.RoomserverInternalAPI,
+ syncProducer *producers.SyncAPI,
+) *OutputStreamEventConsumer {
+ return &OutputStreamEventConsumer{
+ ctx: process.Context(),
+ cfg: cfg,
+ jetstream: js,
+ db: store,
+ durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"),
+ topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent),
+ pgClient: pgClient,
+ userAPI: userAPI,
+ rsAPI: rsAPI,
+ syncProducer: syncProducer,
+ }
+}
+
+func (s *OutputStreamEventConsumer) Start() error {
+ if err := jetstream.JetStreamConsumer(
+ s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
+ nats.DeliverAll(), nats.ManualAck(),
+ ); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
+ var output types.StreamedEvent
+ output.Event = &gomatrixserverlib.HeaderedEvent{}
+ if err := json.Unmarshal(msg.Data, &output); err != nil {
+ log.WithError(err).Errorf("userapi consumer: message parse failure")
+ return true
+ }
+ if output.Event.Event == nil {
+ log.Errorf("userapi consumer: expected event")
+ return true
+ }
+
+ log.WithFields(log.Fields{
+ "event_id": output.Event.EventID(),
+ "event_type": output.Event.Type(),
+ "stream_pos": output.StreamPosition,
+ }).Tracef("Received message from sync API: %#v", output)
+
+ if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil {
+ log.WithFields(log.Fields{
+ "event_id": output.Event.EventID(),
+ }).WithError(err).Errorf("userapi consumer: process room event failure")
+ }
+
+ return true
+}
+
+func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error {
+ members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
+ if err != nil {
+ return fmt.Errorf("s.localRoomMembers: %w", err)
+ }
+
+ if event.Type() == gomatrixserverlib.MRoomMember {
+ cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll)
+ var member *localMembership
+ member, err = newLocalMembership(&cevent)
+ if err != nil {
+ return fmt.Errorf("newLocalMembership: %w", err)
+ }
+ if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName {
+ // localRoomMembers only adds joined members. An invite
+ // should also be pushed to the target user.
+ members = append(members, member)
+ }
+ }
+
+ // TODO: run in parallel with localRoomMembers.
+ roomName, err := s.roomName(ctx, event)
+ if err != nil {
+ return fmt.Errorf("s.roomName: %w", err)
+ }
+
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "room_id": event.RoomID(),
+ "num_members": len(members),
+ "room_size": roomSize,
+ }).Tracef("Notifying members")
+
+ // Notification.UserIsTarget is a per-member field, so we
+ // cannot group all users in a single request.
+ //
+ // TODO: does it have to be set? It's not required, and
+ // removing it means we can send all notifications to
+ // e.g. Element's Push gateway in one go.
+ for _, mem := range members {
+ if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil {
+ log.WithFields(log.Fields{
+ "localpart": mem.Localpart,
+ }).WithError(err).Debugf("Unable to push to local user")
+ continue
+ }
+ }
+
+ return nil
+}
+
+type localMembership struct {
+ gomatrixserverlib.MemberContent
+ UserID string
+ Localpart string
+ Domain gomatrixserverlib.ServerName
+}
+
+func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) {
+ if event.StateKey == nil {
+ return nil, fmt.Errorf("missing state_key")
+ }
+
+ var member localMembership
+ if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil {
+ return nil, err
+ }
+
+ localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey)
+ if err != nil {
+ return nil, err
+ }
+
+ member.UserID = *event.StateKey
+ member.Localpart = localpart
+ member.Domain = domain
+ return &member, nil
+}
+
+// localRoomMembers fetches the current local members of a room, and
+// the total number of members.
+func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
+ req := &rsapi.QueryMembershipsForRoomRequest{
+ RoomID: roomID,
+ JoinedOnly: true,
+ }
+ var res rsapi.QueryMembershipsForRoomResponse
+
+ // XXX: This could potentially race if the state for the event is not known yet
+ // e.g. the event came over federation but we do not have the full state persisted.
+ if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil {
+ return nil, 0, err
+ }
+
+ var members []*localMembership
+ var ntotal int
+ for _, event := range res.JoinEvents {
+ member, err := newLocalMembership(&event)
+ if err != nil {
+ log.WithError(err).Errorf("Parsing MemberContent")
+ continue
+ }
+ if member.Membership != gomatrixserverlib.Join {
+ continue
+ }
+ if member.Domain != s.cfg.Matrix.ServerName {
+ continue
+ }
+
+ ntotal++
+ members = append(members, member)
+ }
+
+ return members, ntotal, nil
+}
+
+// roomName returns the name in the event (if type==m.room.name), or
+// looks it up in roomserver. If there is no name,
+// m.room.canonical_alias is consulted. Returns an empty string if the
+// room has no name.
+func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
+ if event.Type() == gomatrixserverlib.MRoomName {
+ name, err := unmarshalRoomName(event)
+ if err != nil {
+ return "", err
+ }
+
+ if name != "" {
+ return name, nil
+ }
+ }
+
+ req := &rsapi.QueryCurrentStateRequest{
+ RoomID: event.RoomID(),
+ StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple},
+ }
+ var res rsapi.QueryCurrentStateResponse
+
+ if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil {
+ return "", nil
+ }
+
+ if eventS := res.StateEvents[roomNameTuple]; eventS != nil {
+ return unmarshalRoomName(eventS)
+ }
+
+ if event.Type() == gomatrixserverlib.MRoomCanonicalAlias {
+ alias, err := unmarshalCanonicalAlias(event)
+ if err != nil {
+ return "", err
+ }
+
+ if alias != "" {
+ return alias, nil
+ }
+ }
+
+ if event = res.StateEvents[canonicalAliasTuple]; event != nil {
+ return unmarshalCanonicalAlias(event)
+ }
+
+ return "", nil
+}
+
+var (
+ canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias}
+ roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName}
+)
+
+func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) {
+ var nc eventutil.NameContent
+ if err := json.Unmarshal(event.Content(), &nc); err != nil {
+ return "", fmt.Errorf("unmarshaling NameContent: %w", err)
+ }
+
+ return nc.Name, nil
+}
+
+func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) {
+ var cac eventutil.CanonicalAliasContent
+ if err := json.Unmarshal(event.Content(), &cac); err != nil {
+ return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err)
+ }
+
+ return cac.Alias, nil
+}
+
+// notifyLocal finds the right push actions for a local user, given an event.
+func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error {
+ actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
+ if err != nil {
+ return err
+ }
+ a, tweaks, err := pushrules.ActionsToTweaks(actions)
+ if err != nil {
+ return err
+ }
+ // TODO: support coalescing.
+ if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "room_id": event.RoomID(),
+ "localpart": mem.Localpart,
+ }).Tracef("Push rule evaluation rejected the event")
+ return nil
+ }
+
+ devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
+ if err != nil {
+ return err
+ }
+
+ n := &api.Notification{
+ Actions: actions,
+ // UNSPEC: the spec doesn't say this is a ClientEvent, but the
+ // fields seem to match. room_id should be missing, which
+ // matches the behaviour of FormatSync.
+ Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync),
+ // TODO: this is per-device, but it's not part of the primary
+ // key. So inserting one notification per profile tag doesn't
+ // make sense. What is this supposed to be? Sytests require it
+ // to "work", but they only use a single device.
+ ProfileTag: profileTag,
+ RoomID: event.RoomID(),
+ TS: gomatrixserverlib.AsTimestamp(time.Now()),
+ }
+ if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil {
+ return err
+ }
+
+ if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
+ return err
+ }
+
+ // We do this after InsertNotification. Thus, this should always return >=1.
+ userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "room_id": event.RoomID(),
+ "localpart": mem.Localpart,
+ "num_urls": len(devicesByURLAndFormat),
+ "num_unread": userNumUnreadNotifs,
+ }).Tracef("Notifying single member")
+
+ // Push gateways are out of our control, and we cannot risk
+ // looking up the server on a misbehaving push gateway. Each user
+ // receives a goroutine now that all internal API calls have been
+ // made.
+ //
+ // TODO: think about bounding this to one per user, and what
+ // ordering guarantees we must provide.
+ go func() {
+ // This background processing cannot be tied to a request.
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ var rejected []*pushgateway.Device
+ for url, fmts := range devicesByURLAndFormat {
+ for format, devices := range fmts {
+ // TODO: support "email".
+ if !strings.HasPrefix(url, "http") {
+ continue
+ }
+
+ // UNSPEC: the specification suggests there can be
+ // more than one device per request. There is at least
+ // one Sytest that expects one HTTP request per
+ // device, rather than per URL. For now, we must
+ // notify each one separately.
+ for _, dev := range devices {
+ rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs))
+ if err != nil {
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "localpart": mem.Localpart,
+ }).WithError(err).Errorf("Unable to notify HTTP pusher")
+ continue
+ }
+ rejected = append(rejected, rej...)
+ }
+ }
+ }
+
+ if len(rejected) > 0 {
+ s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
+ }
+ }()
+
+ return nil
+}
+
+// evaluatePushRules fetches and evaluates the push rules of a local
+// user. Returns actions (including dont_notify).
+func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
+ if event.Sender() == mem.UserID {
+ // SPEC: Homeservers MUST NOT notify the Push Gateway for
+ // events that the user has sent themselves.
+ return nil, nil
+ }
+
+ var res api.QueryPushRulesResponse
+ if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil {
+ return nil, err
+ }
+
+ ec := &ruleSetEvalContext{
+ ctx: ctx,
+ rsAPI: s.rsAPI,
+ mem: mem,
+ roomID: event.RoomID(),
+ roomSize: roomSize,
+ }
+ eval := pushrules.NewRuleSetEvaluator(ec, &res.RuleSets.Global)
+ rule, err := eval.MatchEvent(event.Event)
+ if err != nil {
+ return nil, err
+ }
+ if rule == nil {
+ // SPEC: If no rules match an event, the homeserver MUST NOT
+ // notify the Push Gateway for that event.
+ return nil, err
+ }
+
+ log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "room_id": event.RoomID(),
+ "localpart": mem.Localpart,
+ "rule_id": rule.RuleID,
+ }).Tracef("Matched a push rule")
+
+ return rule.Actions, nil
+}
+
+type ruleSetEvalContext struct {
+ ctx context.Context
+ rsAPI rsapi.RoomserverInternalAPI
+ mem *localMembership
+ roomID string
+ roomSize int
+}
+
+func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName }
+
+func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
+
+func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) {
+ req := &rsapi.QueryLatestEventsAndStateRequest{
+ RoomID: rse.roomID,
+ StateToFetch: []gomatrixserverlib.StateKeyTuple{
+ {EventType: gomatrixserverlib.MRoomPowerLevels},
+ },
+ }
+ var res rsapi.QueryLatestEventsAndStateResponse
+ if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil {
+ return false, err
+ }
+ for _, ev := range res.StateEvents {
+ if ev.Type() != gomatrixserverlib.MRoomPowerLevels {
+ continue
+ }
+
+ plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event)
+ if err != nil {
+ return false, err
+ }
+ return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil
+ }
+ return true, nil
+}
+
+// localPushDevices pushes to the configured devices of a local
+// user. The map keys are [url][format].
+func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
+ pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
+ if err != nil {
+ return nil, "", err
+ }
+
+ var profileTag string
+ devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices))
+ for _, pusherDevice := range pusherDevices {
+ if profileTag == "" {
+ profileTag = pusherDevice.Pusher.ProfileTag
+ }
+
+ url := pusherDevice.URL
+ if devicesByURL[url] == nil {
+ devicesByURL[url] = make(map[string][]*pushgateway.Device, 2)
+ }
+ devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device)
+ }
+
+ return devicesByURL, profileTag, nil
+}
+
+// notifyHTTP performs a notificatation to a Push Gateway.
+func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
+ logger := log.WithFields(log.Fields{
+ "event_id": event.EventID(),
+ "url": url,
+ "localpart": localpart,
+ "num_devices": len(devices),
+ })
+
+ var req pushgateway.NotifyRequest
+ switch format {
+ case "event_id_only":
+ req = pushgateway.NotifyRequest{
+ Notification: pushgateway.Notification{
+ Counts: &pushgateway.Counts{},
+ Devices: devices,
+ EventID: event.EventID(),
+ RoomID: event.RoomID(),
+ },
+ }
+
+ default:
+ req = pushgateway.NotifyRequest{
+ Notification: pushgateway.Notification{
+ Content: event.Content(),
+ Counts: &pushgateway.Counts{
+ Unread: userNumUnreadNotifs,
+ },
+ Devices: devices,
+ EventID: event.EventID(),
+ ID: event.EventID(),
+ RoomID: event.RoomID(),
+ RoomName: roomName,
+ Sender: event.Sender(),
+ Type: event.Type(),
+ },
+ }
+ if mem, err := event.Membership(); err == nil {
+ req.Notification.Membership = mem
+ }
+ if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) {
+ req.Notification.UserIsTarget = true
+ }
+ }
+
+ logger.Debugf("Notifying push gateway %s", url)
+ var res pushgateway.NotifyResponse
+ if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil {
+ logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
+ return nil, err
+ }
+ logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result")
+
+ if len(res.Rejected) == 0 {
+ return nil, nil
+ }
+
+ devMap := make(map[string]*pushgateway.Device, len(devices))
+ for _, d := range devices {
+ devMap[d.PushKey] = d
+ }
+ rejected := make([]*pushgateway.Device, 0, len(res.Rejected))
+ for _, pushKey := range res.Rejected {
+ d := devMap[pushKey]
+ if d != nil {
+ rejected = append(rejected, d)
+ }
+ }
+
+ return rejected, nil
+}
+
+// deleteRejectedPushers deletes the pushers associated with the given devices.
+func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id0": devices[0].AppID,
+ "num_devices": len(devices),
+ }).Warnf("Deleting pushers rejected by the HTTP push gateway")
+
+ for _, d := range devices {
+ if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ }).WithError(err).Errorf("Unable to delete rejected pusher")
+ }
+ }
+}
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index f54cc613..7a42fc60 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -20,6 +20,8 @@ import (
"encoding/json"
"errors"
"fmt"
+ "strconv"
+ "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@@ -27,16 +29,22 @@ import (
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil"
+ "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
type UserInternalAPI struct {
- DB storage.Database
- ServerName gomatrixserverlib.ServerName
+ DB storage.Database
+ SyncProducer *producers.SyncAPI
+
+ DisableTLSValidation bool
+ ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
KeyAPI keyapi.KeyInternalAPI
@@ -595,3 +603,162 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
}
res.Keys = result
}
+
+func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
+ if req.Limit == 0 || req.Limit > 1000 {
+ req.Limit = 1000
+ }
+
+ var fromID int64
+ var err error
+ if req.From != "" {
+ fromID, err = strconv.ParseInt(req.From, 10, 64)
+ if err != nil {
+ return fmt.Errorf("QueryNotifications: parsing 'from': %w", err)
+ }
+ }
+ var filter tables.NotificationFilter = tables.AllNotifications
+ if req.Only == "highlight" {
+ filter = tables.HighlightNotifications
+ }
+ notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
+ if err != nil {
+ return err
+ }
+ if notifs == nil {
+ // This ensures empty is JSON-encoded as [] instead of null.
+ notifs = []*api.Notification{}
+ }
+ res.Notifications = notifs
+ if lastID >= 0 {
+ res.NextToken = strconv.FormatInt(lastID+1, 10)
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error {
+ util.GetLogger(ctx).WithFields(logrus.Fields{
+ "localpart": req.Localpart,
+ "pushkey": req.Pusher.PushKey,
+ "display_name": req.Pusher.AppDisplayName,
+ }).Info("PerformPusherCreation")
+ if !req.Append {
+ err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey)
+ if err != nil {
+ return err
+ }
+ }
+ if req.Pusher.Kind == "" {
+ return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
+ }
+ if req.Pusher.PushKeyTS == 0 {
+ req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
+ }
+ return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
+}
+
+func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
+ pushers, err := a.DB.GetPushers(ctx, req.Localpart)
+ if err != nil {
+ return err
+ }
+ for i := range pushers {
+ logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
+ if pushers[i].SessionID != req.SessionID {
+ err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
+ var err error
+ res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
+ return err
+}
+
+func (a *UserInternalAPI) PerformPushRulesPut(
+ ctx context.Context,
+ req *api.PerformPushRulesPutRequest,
+ _ *struct{},
+) error {
+ bs, err := json.Marshal(&req.RuleSets)
+ if err != nil {
+ return err
+ }
+ userReq := api.InputAccountDataRequest{
+ UserID: req.UserID,
+ DataType: pushRulesAccountDataType,
+ AccountData: json.RawMessage(bs),
+ }
+ var userRes api.InputAccountDataResponse // empty
+ if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil {
+ return err
+ }
+
+ if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil {
+ util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed")
+ }
+
+ return nil
+}
+
+func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
+ userReq := api.QueryAccountDataRequest{
+ UserID: req.UserID,
+ DataType: pushRulesAccountDataType,
+ }
+ var userRes api.QueryAccountDataResponse
+ if err := a.QueryAccountData(ctx, &userReq, &userRes); err != nil {
+ return err
+ }
+ bs, ok := userRes.GlobalAccountData[pushRulesAccountDataType]
+ if ok {
+ // Legacy Dendrite users will have completely empty push rules, so we should
+ // detect that situation and set some defaults.
+ var rules struct {
+ G struct {
+ Content []json.RawMessage `json:"content"`
+ Override []json.RawMessage `json:"override"`
+ Room []json.RawMessage `json:"room"`
+ Sender []json.RawMessage `json:"sender"`
+ Underride []json.RawMessage `json:"underride"`
+ } `json:"global"`
+ }
+ if err := json.Unmarshal([]byte(bs), &rules); err == nil {
+ count := len(rules.G.Content) + len(rules.G.Override) +
+ len(rules.G.Room) + len(rules.G.Sender) + len(rules.G.Underride)
+ ok = count > 0
+ }
+ }
+ if !ok {
+ // If we didn't find any default push rules then we should just generate some
+ // fresh ones.
+ localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
+ if err != nil {
+ return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
+ }
+ pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, a.ServerName)
+ prbs, err := json.Marshal(pushRuleSets)
+ if err != nil {
+ return fmt.Errorf("failed to marshal default push rules: %w", err)
+ }
+ if err := a.DB.SaveAccountData(ctx, localpart, "", pushRulesAccountDataType, json.RawMessage(prbs)); err != nil {
+ return fmt.Errorf("failed to save default push rules: %w", err)
+ }
+ res.RuleSets = pushRuleSets
+ return nil
+ }
+ var data pushrules.AccountRuleSets
+ if err := json.Unmarshal([]byte(bs), &data); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("json.Unmarshal of push rules failed")
+ return err
+ }
+ res.RuleSets = &data
+ return nil
+}
+
+const pushRulesAccountDataType = "m.push_rules"
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index 1599d463..8ec649ad 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -37,6 +37,9 @@ const (
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
PerformKeyBackupPath = "/userapi/performKeyBackup"
+ PerformPusherSetPath = "/pushserver/performPusherSet"
+ PerformPusherDeletionPath = "/pushserver/performPusherDeletion"
+ PerformPushRulesPutPath = "/pushserver/performPushRulesPut"
QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile"
@@ -46,6 +49,9 @@ const (
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
+ QueryPushersPath = "/pushserver/queryPushers"
+ QueryPushRulesPath = "/pushserver/queryPushRules"
+ QueryNotificationsPath = "/pushserver/queryNotifications"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@@ -249,3 +255,58 @@ func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.Query
res.Error = err.Error()
}
}
+
+func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications")
+ defer span.Finish()
+
+ return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res)
+}
+
+func (h *httpUserInternalAPI) PerformPusherSet(
+ ctx context.Context,
+ request *api.PerformPusherSetRequest,
+ response *struct{},
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet")
+ defer span.Finish()
+
+ apiURL := h.apiURL + PerformPusherSetPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion")
+ defer span.Finish()
+
+ apiURL := h.apiURL + PerformPusherDeletionPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
+
+func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers")
+ defer span.Finish()
+
+ apiURL := h.apiURL + QueryPushersPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
+
+func (h *httpUserInternalAPI) PerformPushRulesPut(
+ ctx context.Context,
+ request *api.PerformPushRulesPutRequest,
+ response *struct{},
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut")
+ defer span.Finish()
+
+ apiURL := h.apiURL + PerformPushRulesPutPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
+func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules")
+ defer span.Finish()
+
+ apiURL := h.apiURL + QueryPushRulesPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go
index d00ee042..526f9957 100644
--- a/userapi/inthttp/server.go
+++ b/userapi/inthttp/server.go
@@ -265,4 +265,86 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
+ internalAPIMux.Handle(QueryNotificationsPath,
+ httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse {
+ var request api.QueryNotificationsRequest
+ var response api.QueryNotificationsResponse
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.QueryNotifications(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+
+ internalAPIMux.Handle(PerformPusherSetPath,
+ httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse {
+ request := api.PerformPusherSetRequest{}
+ response := struct{}{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+ internalAPIMux.Handle(PerformPusherDeletionPath,
+ httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse {
+ request := api.PerformPusherDeletionRequest{}
+ response := struct{}{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+
+ internalAPIMux.Handle(QueryPushersPath,
+ httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse {
+ request := api.QueryPushersRequest{}
+ response := api.QueryPushersResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.QueryPushers(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+
+ internalAPIMux.Handle(PerformPushRulesPutPath,
+ httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse {
+ request := api.PerformPushRulesPutRequest{}
+ response := struct{}{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
+
+ internalAPIMux.Handle(QueryPushRulesPath,
+ httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse {
+ request := api.QueryPushRulesRequest{}
+ response := api.QueryPushRulesResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.QueryPushRules(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
}
diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go
new file mode 100644
index 00000000..4a206f33
--- /dev/null
+++ b/userapi/producers/syncapi.go
@@ -0,0 +1,104 @@
+package producers
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal/eventutil"
+ "github.com/matrix-org/dendrite/setup/jetstream"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/nats-io/nats.go"
+ log "github.com/sirupsen/logrus"
+)
+
+type JetStreamPublisher interface {
+ PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error)
+}
+
+// SyncAPI produces messages for the Sync API server to consume.
+type SyncAPI struct {
+ db storage.Database
+ producer JetStreamPublisher
+ clientDataTopic string
+ notificationDataTopic string
+}
+
+func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
+ return &SyncAPI{
+ db: db,
+ producer: js,
+ clientDataTopic: clientDataTopic,
+ notificationDataTopic: notificationDataTopic,
+ }
+}
+
+// SendAccountData sends account data to the Sync API server.
+func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error {
+ m := &nats.Msg{
+ Subject: p.clientDataTopic,
+ Header: nats.Header{},
+ }
+ m.Header.Set(jetstream.UserID, userID)
+
+ var err error
+ m.Data, err = json.Marshal(eventutil.AccountData{
+ RoomID: roomID,
+ Type: dataType,
+ })
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "user_id": userID,
+ "room_id": roomID,
+ "data_type": dataType,
+ }).Tracef("Producing to topic '%s'", p.clientDataTopic)
+
+ _, err = p.producer.PublishMsg(m)
+ return err
+}
+
+// GetAndSendNotificationData reads the database and sends data about unread
+// notifications to the Sync API server.
+func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
+ localpart, _, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return err
+ }
+
+ ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
+ if err != nil {
+ return err
+ }
+
+ return p.sendNotificationData(userID, &eventutil.NotificationData{
+ RoomID: roomID,
+ UnreadHighlightCount: int(nhighlight),
+ UnreadNotificationCount: int(ntotal),
+ })
+}
+
+// sendNotificationData sends data about unread notifications to the Sync API server.
+func (p *SyncAPI) sendNotificationData(userID string, data *eventutil.NotificationData) error {
+ m := &nats.Msg{
+ Subject: p.notificationDataTopic,
+ Header: nats.Header{},
+ }
+ m.Header.Set(jetstream.UserID, userID)
+
+ var err error
+ m.Data, err = json.Marshal(data)
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "user_id": userID,
+ "room_id": data.RoomID,
+ }).Tracef("Producing to topic '%s'", p.clientDataTopic)
+
+ _, err = p.producer.PublishMsg(m)
+ return err
+}
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index a131dac4..6d22fea9 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
type Database interface {
@@ -89,6 +90,18 @@ type Database interface {
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
+
+ InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
+ DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
+ SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error)
+ GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
+ GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
+ GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
+
+ UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
+ GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
+ RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
+ RemovePushers(ctx context.Context, appid, pushkey string) error
}
// Err3PIDInUse is the error returned when trying to save an association involving
diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go
new file mode 100644
index 00000000..7bcc0f9c
--- /dev/null
+++ b/userapi/storage/postgres/notifications_table.go
@@ -0,0 +1,219 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+type notificationsStatements struct {
+ insertStmt *sql.Stmt
+ deleteUpToStmt *sql.Stmt
+ updateReadStmt *sql.Stmt
+ selectStmt *sql.Stmt
+ selectCountStmt *sql.Stmt
+ selectRoomCountsStmt *sql.Stmt
+}
+
+const notificationSchema = `
+CREATE TABLE IF NOT EXISTS userapi_notifications (
+ id BIGSERIAL PRIMARY KEY,
+ localpart TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ stream_pos BIGINT NOT NULL,
+ ts_ms BIGINT NOT NULL,
+ highlight BOOLEAN NOT NULL,
+ notification_json TEXT NOT NULL,
+ read BOOLEAN NOT NULL DEFAULT FALSE
+);
+
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
+`
+
+const insertNotificationSQL = "" +
+ "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+
+const deleteNotificationsUpToSQL = "" +
+ "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
+
+const updateNotificationReadSQL = "" +
+ "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
+
+const selectNotificationSQL = "" +
+ "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
+ "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read ORDER BY localpart, id LIMIT $4"
+
+const selectNotificationCountSQL = "" +
+ "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
+ "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read"
+
+const selectRoomNotificationCountsSQL = "" +
+ "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
+ "WHERE localpart = $1 AND room_id = $2 AND NOT read"
+
+func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
+ s := &notificationsStatements{}
+ _, err := db.Exec(notificationSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertStmt, insertNotificationSQL},
+ {&s.deleteUpToStmt, deleteNotificationsUpToSQL},
+ {&s.updateReadStmt, updateNotificationReadSQL},
+ {&s.selectStmt, selectNotificationSQL},
+ {&s.selectCountStmt, selectNotificationCountSQL},
+ {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
+ }.Prepare(db)
+}
+
+// Insert inserts a notification into the database.
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
+ roomID, tsMS := n.RoomID, n.TS
+ nn := *n
+ // Clears out fields that have their own columns to (1) shrink the
+ // data and (2) avoid difficult-to-debug inconsistency bugs.
+ nn.RoomID = ""
+ nn.TS, nn.Read = 0, false
+ bs, err := json.Marshal(nn)
+ if err != nil {
+ return err
+ }
+ _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
+ return err
+}
+
+// DeleteUpTo deletes all previous notifications, up to and including the event.
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+// UpdateRead updates the "read" value for an event.
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
+
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ var maxID int64 = -1
+ var notifs []*api.Notification
+ for rows.Next() {
+ var id int64
+ var roomID string
+ var ts gomatrixserverlib.Timestamp
+ var read bool
+ var jsonStr string
+ err = rows.Scan(
+ &id,
+ &roomID,
+ &ts,
+ &read,
+ &jsonStr)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ var n api.Notification
+ err := json.Unmarshal([]byte(jsonStr), &n)
+ if err != nil {
+ return nil, 0, err
+ }
+ n.RoomID = roomID
+ n.TS = ts
+ n.Read = read
+ notifs = append(notifs, &n)
+
+ if maxID < id {
+ maxID = id
+ }
+ }
+ return notifs, maxID, rows.Err()
+}
+
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
+
+ if err != nil {
+ return 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var count int64
+ if err := rows.Scan(&count); err != nil {
+ return 0, err
+ }
+
+ return count, nil
+ }
+ return 0, rows.Err()
+}
+
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
+
+ if err != nil {
+ return 0, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var total, highlight int64
+ if err := rows.Scan(&total, &highlight); err != nil {
+ return 0, 0, err
+ }
+
+ return total, highlight, nil
+ }
+ return 0, 0, rows.Err()
+}
diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go
new file mode 100644
index 00000000..670dc916
--- /dev/null
+++ b/userapi/storage/postgres/pusher_table.go
@@ -0,0 +1,157 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
+const pushersSchema = `
+CREATE TABLE IF NOT EXISTS userapi_pushers (
+ id BIGSERIAL PRIMARY KEY,
+ -- The Matrix user ID localpart for this pusher
+ localpart TEXT NOT NULL,
+ session_id BIGINT DEFAULT NULL,
+ profile_tag TEXT,
+ kind TEXT NOT NULL,
+ app_id TEXT NOT NULL,
+ app_display_name TEXT NOT NULL,
+ device_display_name TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
+ lang TEXT NOT NULL,
+ data TEXT NOT NULL
+);
+
+-- For faster deleting by app_id, pushkey pair.
+CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
+
+-- For faster retrieving by localpart.
+CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
+
+-- Pushkey must be unique for a given user and app.
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
+`
+
+const insertPusherSQL = "" +
+ "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
+ "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
+
+const selectPushersSQL = "" +
+ "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
+
+const deletePusherSQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
+
+const deletePushersByAppIdAndPushKeySQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
+
+func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) {
+ s := &pushersStatements{}
+ _, err := db.Exec(pushersSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertPusherStmt, insertPusherSQL},
+ {&s.selectPushersStmt, selectPushersSQL},
+ {&s.deletePusherStmt, deletePusherSQL},
+ {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
+ }.Prepare(db)
+}
+
+type pushersStatements struct {
+ insertPusherStmt *sql.Stmt
+ selectPushersStmt *sql.Stmt
+ deletePusherStmt *sql.Stmt
+ deletePushersByAppIdAndPushKeyStmt *sql.Stmt
+}
+
+// insertPusher creates a new pusher.
+// Returns an error if the user already has a pusher with the given pusher pushkey.
+// Returns nil error success.
+func (s *pushersStatements) InsertPusher(
+ ctx context.Context, txn *sql.Tx, session_id int64,
+ pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
+ logrus.Debugf("Created pusher %d", session_id)
+ return err
+}
+
+func (s *pushersStatements) SelectPushers(
+ ctx context.Context, txn *sql.Tx, localpart string,
+) ([]api.Pusher, error) {
+ pushers := []api.Pusher{}
+ rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
+
+ if err != nil {
+ return pushers, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
+
+ for rows.Next() {
+ var pusher api.Pusher
+ var data []byte
+ err = rows.Scan(
+ &pusher.SessionID,
+ &pusher.PushKey,
+ &pusher.PushKeyTS,
+ &pusher.Kind,
+ &pusher.AppID,
+ &pusher.AppDisplayName,
+ &pusher.DeviceDisplayName,
+ &pusher.ProfileTag,
+ &pusher.Language,
+ &data)
+ if err != nil {
+ return pushers, err
+ }
+ err := json.Unmarshal(data, &pusher.Data)
+ if err != nil {
+ return pushers, err
+ }
+ pushers = append(pushers, pusher)
+ }
+
+ logrus.Debugf("Database returned %d pushers", len(pushers))
+ return pushers, rows.Err()
+}
+
+// deletePusher removes a single pusher by pushkey and user localpart.
+func (s *pushersStatements) DeletePusher(
+ ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
+ return err
+}
+
+func (s *pushersStatements) DeletePushers(
+ ctx context.Context, txn *sql.Tx, appid, pushkey string,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey)
+ return err
+}
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index ac5c59b8..c74a999f 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
}
+ pusherTable, err := NewPostgresPusherTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
+ }
+ notificationsTable, err := NewPostgresNotificationTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
+ }
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
@@ -95,6 +103,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
+ Pushers: pusherTable,
+ Notifications: notificationsTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewDummyWriter(),
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index 5f1f9500..a58974b4 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -29,6 +29,7 @@ import (
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
@@ -47,6 +48,8 @@ type Database struct {
KeyBackupVersions tables.KeyBackupVersionTable
Devices tables.DevicesTable
LoginTokens tables.LoginTokenTable
+ Notifications tables.NotificationTable
+ Pushers tables.PusherTable
LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName
BcryptCost int
@@ -160,15 +163,12 @@ func (d *Database) createAccount(
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
- if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
- "global": {
- "content": [],
- "override": [],
- "room": [],
- "sender": [],
- "underride": []
- }
- }`)); err != nil {
+ pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
+ prbs, err := json.Marshal(pushRuleSets)
+ if err != nil {
+ return nil, err
+ }
+ if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err
}
return account, nil
@@ -670,3 +670,94 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.LoginTokens.SelectLoginToken(ctx, token)
}
+
+func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
+ })
+}
+
+func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
+ return err
+ })
+ return
+}
+
+func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
+ return err
+ })
+ return
+}
+
+func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
+}
+
+func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
+ return d.Notifications.SelectCount(ctx, nil, localpart, filter)
+}
+
+func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
+ return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
+}
+
+func (d *Database) UpsertPusher(
+ ctx context.Context, p api.Pusher, localpart string,
+) error {
+ data, err := json.Marshal(p.Data)
+ if err != nil {
+ return err
+ }
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.Pushers.InsertPusher(
+ ctx, txn,
+ p.SessionID,
+ p.PushKey,
+ p.PushKeyTS,
+ p.Kind,
+ p.AppID,
+ p.AppDisplayName,
+ p.DeviceDisplayName,
+ p.ProfileTag,
+ p.Language,
+ string(data),
+ localpart)
+ })
+}
+
+// GetPushers returns the pushers matching the given localpart.
+func (d *Database) GetPushers(
+ ctx context.Context, localpart string,
+) ([]api.Pusher, error) {
+ return d.Pushers.SelectPushers(ctx, nil, localpart)
+}
+
+// RemovePusher deletes one pusher
+// Invoked when `append` is true and `kind` is null in
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
+func (d *Database) RemovePusher(
+ ctx context.Context, appid, pushkey, localpart string,
+) error {
+ return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
+ err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
+ if err == sql.ErrNoRows {
+ return nil
+ }
+ return err
+ })
+}
+
+// RemovePushers deletes all pushers that match given App Id and Push Key pair.
+// Invoked when `append` parameter is false in
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
+func (d *Database) RemovePushers(
+ ctx context.Context, appid, pushkey string,
+) error {
+ return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
+ return d.Pushers.DeletePushers(ctx, txn, appid, pushkey)
+ })
+}
diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go
new file mode 100644
index 00000000..fcfb1aad
--- /dev/null
+++ b/userapi/storage/sqlite3/notifications_table.go
@@ -0,0 +1,219 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+type notificationsStatements struct {
+ insertStmt *sql.Stmt
+ deleteUpToStmt *sql.Stmt
+ updateReadStmt *sql.Stmt
+ selectStmt *sql.Stmt
+ selectCountStmt *sql.Stmt
+ selectRoomCountsStmt *sql.Stmt
+}
+
+const notificationSchema = `
+CREATE TABLE IF NOT EXISTS userapi_notifications (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ localpart TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ stream_pos BIGINT NOT NULL,
+ ts_ms BIGINT NOT NULL,
+ highlight BOOLEAN NOT NULL,
+ notification_json TEXT NOT NULL,
+ read BOOLEAN NOT NULL DEFAULT FALSE
+);
+
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
+`
+
+const insertNotificationSQL = "" +
+ "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+
+const deleteNotificationsUpToSQL = "" +
+ "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
+
+const updateNotificationReadSQL = "" +
+ "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
+
+const selectNotificationSQL = "" +
+ "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
+ "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read ORDER BY localpart, id LIMIT $4"
+
+const selectNotificationCountSQL = "" +
+ "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
+ "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read"
+
+const selectRoomNotificationCountsSQL = "" +
+ "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
+ "WHERE localpart = $1 AND room_id = $2 AND NOT read"
+
+func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
+ s := &notificationsStatements{}
+ _, err := db.Exec(notificationSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertStmt, insertNotificationSQL},
+ {&s.deleteUpToStmt, deleteNotificationsUpToSQL},
+ {&s.updateReadStmt, updateNotificationReadSQL},
+ {&s.selectStmt, selectNotificationSQL},
+ {&s.selectCountStmt, selectNotificationCountSQL},
+ {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
+ }.Prepare(db)
+}
+
+// Insert inserts a notification into the database.
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
+ roomID, tsMS := n.RoomID, n.TS
+ nn := *n
+ // Clears out fields that have their own columns to (1) shrink the
+ // data and (2) avoid difficult-to-debug inconsistency bugs.
+ nn.RoomID = ""
+ nn.TS, nn.Read = 0, false
+ bs, err := json.Marshal(nn)
+ if err != nil {
+ return err
+ }
+ _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
+ return err
+}
+
+// DeleteUpTo deletes all previous notifications, up to and including the event.
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+// UpdateRead updates the "read" value for an event.
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
+
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ var maxID int64 = -1
+ var notifs []*api.Notification
+ for rows.Next() {
+ var id int64
+ var roomID string
+ var ts gomatrixserverlib.Timestamp
+ var read bool
+ var jsonStr string
+ err = rows.Scan(
+ &id,
+ &roomID,
+ &ts,
+ &read,
+ &jsonStr)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ var n api.Notification
+ err := json.Unmarshal([]byte(jsonStr), &n)
+ if err != nil {
+ return nil, 0, err
+ }
+ n.RoomID = roomID
+ n.TS = ts
+ n.Read = read
+ notifs = append(notifs, &n)
+
+ if maxID < id {
+ maxID = id
+ }
+ }
+ return notifs, maxID, rows.Err()
+}
+
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
+
+ if err != nil {
+ return 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var count int64
+ if err := rows.Scan(&count); err != nil {
+ return 0, err
+ }
+
+ return count, nil
+ }
+ return 0, rows.Err()
+}
+
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
+
+ if err != nil {
+ return 0, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var total, highlight int64
+ if err := rows.Scan(&total, &highlight); err != nil {
+ return 0, 0, err
+ }
+
+ return total, highlight, nil
+ }
+ return 0, 0, rows.Err()
+}
diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go
new file mode 100644
index 00000000..e718792e
--- /dev/null
+++ b/userapi/storage/sqlite3/pusher_table.go
@@ -0,0 +1,157 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
+const pushersSchema = `
+CREATE TABLE IF NOT EXISTS userapi_pushers (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The Matrix user ID localpart for this pusher
+ localpart TEXT NOT NULL,
+ session_id BIGINT DEFAULT NULL,
+ profile_tag TEXT,
+ kind TEXT NOT NULL,
+ app_id TEXT NOT NULL,
+ app_display_name TEXT NOT NULL,
+ device_display_name TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
+ lang TEXT NOT NULL,
+ data TEXT NOT NULL
+);
+
+-- For faster deleting by app_id, pushkey pair.
+CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
+
+-- For faster retrieving by localpart.
+CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
+
+-- Pushkey must be unique for a given user and app.
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
+`
+
+const insertPusherSQL = "" +
+ "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
+ "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
+
+const selectPushersSQL = "" +
+ "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
+
+const deletePusherSQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
+
+const deletePushersByAppIdAndPushKeySQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
+
+func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) {
+ s := &pushersStatements{}
+ _, err := db.Exec(pushersSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertPusherStmt, insertPusherSQL},
+ {&s.selectPushersStmt, selectPushersSQL},
+ {&s.deletePusherStmt, deletePusherSQL},
+ {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
+ }.Prepare(db)
+}
+
+type pushersStatements struct {
+ insertPusherStmt *sql.Stmt
+ selectPushersStmt *sql.Stmt
+ deletePusherStmt *sql.Stmt
+ deletePushersByAppIdAndPushKeyStmt *sql.Stmt
+}
+
+// insertPusher creates a new pusher.
+// Returns an error if the user already has a pusher with the given pusher pushkey.
+// Returns nil error success.
+func (s *pushersStatements) InsertPusher(
+ ctx context.Context, txn *sql.Tx, session_id int64,
+ pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+) error {
+ _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
+ logrus.Debugf("Created pusher %d", session_id)
+ return err
+}
+
+func (s *pushersStatements) SelectPushers(
+ ctx context.Context, txn *sql.Tx, localpart string,
+) ([]api.Pusher, error) {
+ pushers := []api.Pusher{}
+ rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
+
+ if err != nil {
+ return pushers, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
+
+ for rows.Next() {
+ var pusher api.Pusher
+ var data []byte
+ err = rows.Scan(
+ &pusher.SessionID,
+ &pusher.PushKey,
+ &pusher.PushKeyTS,
+ &pusher.Kind,
+ &pusher.AppID,
+ &pusher.AppDisplayName,
+ &pusher.DeviceDisplayName,
+ &pusher.ProfileTag,
+ &pusher.Language,
+ &data)
+ if err != nil {
+ return pushers, err
+ }
+ err := json.Unmarshal(data, &pusher.Data)
+ if err != nil {
+ return pushers, err
+ }
+ pushers = append(pushers, pusher)
+ }
+
+ logrus.Debugf("Database returned %d pushers", len(pushers))
+ return pushers, rows.Err()
+}
+
+// deletePusher removes a single pusher by pushkey and user localpart.
+func (s *pushersStatements) DeletePusher(
+ ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
+) error {
+ _, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart)
+ return err
+}
+
+func (s *pushersStatements) DeletePushers(
+ ctx context.Context, txn *sql.Tx, appid, pushkey string,
+) error {
+ _, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey)
+ return err
+}
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index 98c24497..b5bb96c4 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
}
+ pusherTable, err := NewSQLitePusherTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
+ }
+ notificationsTable, err := NewSQLiteNotificationTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
+ }
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
@@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
+ Pushers: pusherTable,
+ Notifications: notificationsTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 12939ced..815e5119 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
)
type AccountDataTable interface {
@@ -93,3 +94,42 @@ type ThreePIDTable interface {
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
}
+
+type PusherTable interface {
+ InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
+ SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
+ DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
+ DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
+}
+
+type NotificationTable interface {
+ Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error
+ DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error)
+ UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error)
+ Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
+ SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
+ SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
+}
+
+type NotificationFilter uint32
+
+const (
+ // HighlightNotifications returns notifications that had a
+ // "highlight" tweak assigned to them from evaluating push rules.
+ HighlightNotifications NotificationFilter = 1 << iota
+
+ // NonHighlightNotifications returns notifications that don't
+ // match HighlightNotifications.
+ NonHighlightNotifications
+
+ // NoNotifications is a filter to exclude all types of
+ // notifications. It's useful as a zero value, but isn't likely to
+ // be used in a call to Notifications.Select*.
+ NoNotifications NotificationFilter = 0
+
+ // AllNotifications is a filter to include all types of
+ // notifications in Notifications.Select*. Note that PostgreSQL
+ // balks if this doesn't fit in INTEGER, even though we use
+ // uint32.
+ AllNotifications NotificationFilter = (1 << 31) - 1
+)
diff --git a/userapi/userapi.go b/userapi/userapi.go
index 4a5793ab..2382e951 100644
--- a/userapi/userapi.go
+++ b/userapi/userapi.go
@@ -18,11 +18,17 @@ import (
"time"
"github.com/gorilla/mux"
+ "github.com/matrix-org/dendrite/internal/pushgateway"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
+ rsapi "github.com/matrix-org/dendrite/roomserver/api"
+ "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/consumers"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
+ "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/sirupsen/logrus"
)
@@ -36,26 +42,49 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
- accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
+ base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI,
+ appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
+ rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client,
) api.UserInternalAPI {
db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db")
}
- return newInternalAPI(db, cfg, appServices, keyAPI)
-}
+ js := jetstream.Prepare(&cfg.Matrix.JetStream)
-func newInternalAPI(
- db storage.Database,
- cfg *config.UserAPI,
- appServices []config.ApplicationService,
- keyAPI keyapi.KeyInternalAPI,
-) api.UserInternalAPI {
- return &internal.UserInternalAPI{
- DB: db,
- ServerName: cfg.Matrix.ServerName,
- AppServices: appServices,
- KeyAPI: keyAPI,
+ syncProducer := producers.NewSyncAPI(
+ db, js,
+ // TODO: user API should handle syncs for account data. Right now,
+ // it's handled by clientapi, and hence uses its topic. When user
+ // API handles it for all account data, we can remove it from
+ // here.
+ cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
+ cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
+ )
+
+ userAPI := &internal.UserInternalAPI{
+ DB: db,
+ SyncProducer: syncProducer,
+ ServerName: cfg.Matrix.ServerName,
+ AppServices: appServices,
+ KeyAPI: keyAPI,
+ DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
+ }
+
+ readConsumer := consumers.NewOutputReadUpdateConsumer(
+ base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer,
+ )
+ if err := readConsumer.Start(); err != nil {
+ logrus.WithError(err).Panic("failed to start user API read update consumer")
+ }
+
+ eventConsumer := consumers.NewOutputStreamEventConsumer(
+ base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer,
+ )
+ if err := eventConsumer.Start(); err != nil {
+ logrus.WithError(err).Panic("failed to start user API streamed event consumer")
}
+
+ return userAPI
}
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 4214c07f..25319c4b 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/storage"
)
@@ -62,7 +63,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
},
}
- return newInternalAPI(accountDB, cfg, nil, nil), accountDB
+ return &internal.UserInternalAPI{
+ DB: accountDB,
+ ServerName: cfg.Matrix.ServerName,
+ }, accountDB
}
func TestQueryProfile(t *testing.T) {
diff --git a/userapi/util/devices.go b/userapi/util/devices.go
new file mode 100644
index 00000000..cbf3bd28
--- /dev/null
+++ b/userapi/util/devices.go
@@ -0,0 +1,100 @@
+package util
+
+import (
+ "context"
+
+ "github.com/matrix-org/dendrite/internal/pushgateway"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ log "github.com/sirupsen/logrus"
+)
+
+type PusherDevice struct {
+ Device pushgateway.Device
+ Pusher *api.Pusher
+ URL string
+ Format string
+}
+
+// GetPushDevices pushes to the configured devices of a local user.
+func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
+ pushers, err := db.GetPushers(ctx, localpart)
+ if err != nil {
+ return nil, err
+ }
+
+ devices := make([]*PusherDevice, 0, len(pushers))
+ for _, pusher := range pushers {
+ var url, format string
+ data := pusher.Data
+ switch pusher.Kind {
+ case api.EmailKind:
+ url = "mailto:"
+
+ case api.HTTPKind:
+ // TODO: The spec says only event_id_only is supported,
+ // but Sytests assume "" means "full notification".
+ fmtIface := pusher.Data["format"]
+ var ok bool
+ format, ok = fmtIface.(string)
+ if ok && format != "event_id_only" {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id": pusher.AppID,
+ }).Errorf("Only data.format event_id_only or empty is supported")
+ continue
+ }
+
+ urlIface := pusher.Data["url"]
+ url, ok = urlIface.(string)
+ if !ok {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id": pusher.AppID,
+ }).Errorf("No data.url configured for HTTP Pusher")
+ continue
+ }
+ data = mapWithout(data, "url")
+
+ default:
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id": pusher.AppID,
+ "kind": pusher.Kind,
+ }).Errorf("Unhandled pusher kind")
+ continue
+ }
+
+ devices = append(devices, &PusherDevice{
+ Device: pushgateway.Device{
+ AppID: pusher.AppID,
+ Data: data,
+ PushKey: pusher.PushKey,
+ PushKeyTS: pusher.PushKeyTS,
+ Tweaks: tweaks,
+ },
+ Pusher: &pusher,
+ URL: url,
+ Format: format,
+ })
+ }
+
+ return devices, nil
+}
+
+// mapWithout returns a shallow copy of the map, without the given
+// key. Returns nil if the resulting map is empty.
+func mapWithout(m map[string]interface{}, key string) map[string]interface{} {
+ ret := make(map[string]interface{}, len(m))
+ for k, v := range m {
+ // The specification says we do not send "url".
+ if k == key {
+ continue
+ }
+ ret[k] = v
+ }
+ if len(ret) == 0 {
+ return nil
+ }
+ return ret
+}
diff --git a/userapi/util/notify.go b/userapi/util/notify.go
new file mode 100644
index 00000000..ff206bd3
--- /dev/null
+++ b/userapi/util/notify.go
@@ -0,0 +1,76 @@
+package util
+
+import (
+ "context"
+ "strings"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal/pushgateway"
+ "github.com/matrix-org/dendrite/userapi/storage"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ log "github.com/sirupsen/logrus"
+)
+
+// NotifyUserCountsAsync sends notifications to a local user's
+// notification destinations. Database lookups run synchronously, but
+// a single goroutine is started when talking to the Push
+// gateways. There is no way to know when the background goroutine has
+// finished.
+func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
+ pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
+ if err != nil {
+ return err
+ }
+
+ if len(pusherDevices) == 0 {
+ return nil
+ }
+
+ userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
+ if err != nil {
+ return err
+ }
+
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id0": pusherDevices[0].Device.AppID,
+ "pushkey": pusherDevices[0].Device.PushKey,
+ }).Tracef("Notifying HTTP push gateway about notification counts")
+
+ // TODO: think about bounding this to one per user, and what
+ // ordering guarantees we must provide.
+ go func() {
+ // This background processing cannot be tied to a request.
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ // TODO: we could batch all devices with the same URL, but
+ // Sytest requires consumers/roomserver.go to do it
+ // one-by-one, so we do the same here.
+ for _, pusherDevice := range pusherDevices {
+ // TODO: support "email".
+ if !strings.HasPrefix(pusherDevice.URL, "http") {
+ continue
+ }
+
+ req := pushgateway.NotifyRequest{
+ Notification: pushgateway.Notification{
+ Counts: &pushgateway.Counts{
+ Unread: int(userNumUnreadNotifs),
+ },
+ Devices: []*pushgateway.Device{&pusherDevice.Device},
+ },
+ }
+ if err := pgClient.Notify(ctx, pusherDevice.URL, &req, &pushgateway.NotifyResponse{}); err != nil {
+ log.WithFields(log.Fields{
+ "localpart": localpart,
+ "app_id0": pusherDevice.Device.AppID,
+ "pushkey": pusherDevice.Device.PushKey,
+ }).WithError(err).Error("HTTP push gateway request failed")
+ return
+ }
+ }
+ }()
+
+ return nil
+}