aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage')
-rw-r--r--userapi/storage/interface.go13
-rw-r--r--userapi/storage/postgres/notifications_table.go219
-rw-r--r--userapi/storage/postgres/pusher_table.go157
-rw-r--r--userapi/storage/postgres/storage.go10
-rw-r--r--userapi/storage/shared/storage.go109
-rw-r--r--userapi/storage/sqlite3/notifications_table.go219
-rw-r--r--userapi/storage/sqlite3/pusher_table.go157
-rw-r--r--userapi/storage/sqlite3/storage.go10
-rw-r--r--userapi/storage/tables/interface.go40
9 files changed, 925 insertions, 9 deletions
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index a131dac4..6d22fea9 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
)
type Database interface {
@@ -89,6 +90,18 @@ type Database interface {
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
+
+ InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
+ DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
+ SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error)
+ GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
+ GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
+ GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
+
+ UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
+ GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
+ RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
+ RemovePushers(ctx context.Context, appid, pushkey string) error
}
// Err3PIDInUse is the error returned when trying to save an association involving
diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go
new file mode 100644
index 00000000..7bcc0f9c
--- /dev/null
+++ b/userapi/storage/postgres/notifications_table.go
@@ -0,0 +1,219 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+type notificationsStatements struct {
+ insertStmt *sql.Stmt
+ deleteUpToStmt *sql.Stmt
+ updateReadStmt *sql.Stmt
+ selectStmt *sql.Stmt
+ selectCountStmt *sql.Stmt
+ selectRoomCountsStmt *sql.Stmt
+}
+
+const notificationSchema = `
+CREATE TABLE IF NOT EXISTS userapi_notifications (
+ id BIGSERIAL PRIMARY KEY,
+ localpart TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ stream_pos BIGINT NOT NULL,
+ ts_ms BIGINT NOT NULL,
+ highlight BOOLEAN NOT NULL,
+ notification_json TEXT NOT NULL,
+ read BOOLEAN NOT NULL DEFAULT FALSE
+);
+
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
+`
+
+const insertNotificationSQL = "" +
+ "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+
+const deleteNotificationsUpToSQL = "" +
+ "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
+
+const updateNotificationReadSQL = "" +
+ "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
+
+const selectNotificationSQL = "" +
+ "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
+ "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read ORDER BY localpart, id LIMIT $4"
+
+const selectNotificationCountSQL = "" +
+ "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
+ "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read"
+
+const selectRoomNotificationCountsSQL = "" +
+ "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
+ "WHERE localpart = $1 AND room_id = $2 AND NOT read"
+
+func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
+ s := &notificationsStatements{}
+ _, err := db.Exec(notificationSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertStmt, insertNotificationSQL},
+ {&s.deleteUpToStmt, deleteNotificationsUpToSQL},
+ {&s.updateReadStmt, updateNotificationReadSQL},
+ {&s.selectStmt, selectNotificationSQL},
+ {&s.selectCountStmt, selectNotificationCountSQL},
+ {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
+ }.Prepare(db)
+}
+
+// Insert inserts a notification into the database.
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
+ roomID, tsMS := n.RoomID, n.TS
+ nn := *n
+ // Clears out fields that have their own columns to (1) shrink the
+ // data and (2) avoid difficult-to-debug inconsistency bugs.
+ nn.RoomID = ""
+ nn.TS, nn.Read = 0, false
+ bs, err := json.Marshal(nn)
+ if err != nil {
+ return err
+ }
+ _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
+ return err
+}
+
+// DeleteUpTo deletes all previous notifications, up to and including the event.
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+// UpdateRead updates the "read" value for an event.
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
+
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ var maxID int64 = -1
+ var notifs []*api.Notification
+ for rows.Next() {
+ var id int64
+ var roomID string
+ var ts gomatrixserverlib.Timestamp
+ var read bool
+ var jsonStr string
+ err = rows.Scan(
+ &id,
+ &roomID,
+ &ts,
+ &read,
+ &jsonStr)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ var n api.Notification
+ err := json.Unmarshal([]byte(jsonStr), &n)
+ if err != nil {
+ return nil, 0, err
+ }
+ n.RoomID = roomID
+ n.TS = ts
+ n.Read = read
+ notifs = append(notifs, &n)
+
+ if maxID < id {
+ maxID = id
+ }
+ }
+ return notifs, maxID, rows.Err()
+}
+
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
+
+ if err != nil {
+ return 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var count int64
+ if err := rows.Scan(&count); err != nil {
+ return 0, err
+ }
+
+ return count, nil
+ }
+ return 0, rows.Err()
+}
+
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
+
+ if err != nil {
+ return 0, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var total, highlight int64
+ if err := rows.Scan(&total, &highlight); err != nil {
+ return 0, 0, err
+ }
+
+ return total, highlight, nil
+ }
+ return 0, 0, rows.Err()
+}
diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go
new file mode 100644
index 00000000..670dc916
--- /dev/null
+++ b/userapi/storage/postgres/pusher_table.go
@@ -0,0 +1,157 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
+const pushersSchema = `
+CREATE TABLE IF NOT EXISTS userapi_pushers (
+ id BIGSERIAL PRIMARY KEY,
+ -- The Matrix user ID localpart for this pusher
+ localpart TEXT NOT NULL,
+ session_id BIGINT DEFAULT NULL,
+ profile_tag TEXT,
+ kind TEXT NOT NULL,
+ app_id TEXT NOT NULL,
+ app_display_name TEXT NOT NULL,
+ device_display_name TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
+ lang TEXT NOT NULL,
+ data TEXT NOT NULL
+);
+
+-- For faster deleting by app_id, pushkey pair.
+CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
+
+-- For faster retrieving by localpart.
+CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
+
+-- Pushkey must be unique for a given user and app.
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
+`
+
+const insertPusherSQL = "" +
+ "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
+ "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
+
+const selectPushersSQL = "" +
+ "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
+
+const deletePusherSQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
+
+const deletePushersByAppIdAndPushKeySQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
+
+func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) {
+ s := &pushersStatements{}
+ _, err := db.Exec(pushersSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertPusherStmt, insertPusherSQL},
+ {&s.selectPushersStmt, selectPushersSQL},
+ {&s.deletePusherStmt, deletePusherSQL},
+ {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
+ }.Prepare(db)
+}
+
+type pushersStatements struct {
+ insertPusherStmt *sql.Stmt
+ selectPushersStmt *sql.Stmt
+ deletePusherStmt *sql.Stmt
+ deletePushersByAppIdAndPushKeyStmt *sql.Stmt
+}
+
+// insertPusher creates a new pusher.
+// Returns an error if the user already has a pusher with the given pusher pushkey.
+// Returns nil error success.
+func (s *pushersStatements) InsertPusher(
+ ctx context.Context, txn *sql.Tx, session_id int64,
+ pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
+ logrus.Debugf("Created pusher %d", session_id)
+ return err
+}
+
+func (s *pushersStatements) SelectPushers(
+ ctx context.Context, txn *sql.Tx, localpart string,
+) ([]api.Pusher, error) {
+ pushers := []api.Pusher{}
+ rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
+
+ if err != nil {
+ return pushers, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
+
+ for rows.Next() {
+ var pusher api.Pusher
+ var data []byte
+ err = rows.Scan(
+ &pusher.SessionID,
+ &pusher.PushKey,
+ &pusher.PushKeyTS,
+ &pusher.Kind,
+ &pusher.AppID,
+ &pusher.AppDisplayName,
+ &pusher.DeviceDisplayName,
+ &pusher.ProfileTag,
+ &pusher.Language,
+ &data)
+ if err != nil {
+ return pushers, err
+ }
+ err := json.Unmarshal(data, &pusher.Data)
+ if err != nil {
+ return pushers, err
+ }
+ pushers = append(pushers, pusher)
+ }
+
+ logrus.Debugf("Database returned %d pushers", len(pushers))
+ return pushers, rows.Err()
+}
+
+// deletePusher removes a single pusher by pushkey and user localpart.
+func (s *pushersStatements) DeletePusher(
+ ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
+ return err
+}
+
+func (s *pushersStatements) DeletePushers(
+ ctx context.Context, txn *sql.Tx, appid, pushkey string,
+) error {
+ _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey)
+ return err
+}
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index ac5c59b8..c74a999f 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
}
+ pusherTable, err := NewPostgresPusherTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
+ }
+ notificationsTable, err := NewPostgresNotificationTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
+ }
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
@@ -95,6 +103,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
+ Pushers: pusherTable,
+ Notifications: notificationsTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewDummyWriter(),
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index 5f1f9500..a58974b4 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -29,6 +29,7 @@ import (
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
@@ -47,6 +48,8 @@ type Database struct {
KeyBackupVersions tables.KeyBackupVersionTable
Devices tables.DevicesTable
LoginTokens tables.LoginTokenTable
+ Notifications tables.NotificationTable
+ Pushers tables.PusherTable
LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName
BcryptCost int
@@ -160,15 +163,12 @@ func (d *Database) createAccount(
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
- if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
- "global": {
- "content": [],
- "override": [],
- "room": [],
- "sender": [],
- "underride": []
- }
- }`)); err != nil {
+ pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
+ prbs, err := json.Marshal(pushRuleSets)
+ if err != nil {
+ return nil, err
+ }
+ if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err
}
return account, nil
@@ -670,3 +670,94 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.LoginTokens.SelectLoginToken(ctx, token)
}
+
+func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
+ })
+}
+
+func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
+ return err
+ })
+ return
+}
+
+func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
+ return err
+ })
+ return
+}
+
+func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
+}
+
+func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
+ return d.Notifications.SelectCount(ctx, nil, localpart, filter)
+}
+
+func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
+ return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
+}
+
+func (d *Database) UpsertPusher(
+ ctx context.Context, p api.Pusher, localpart string,
+) error {
+ data, err := json.Marshal(p.Data)
+ if err != nil {
+ return err
+ }
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.Pushers.InsertPusher(
+ ctx, txn,
+ p.SessionID,
+ p.PushKey,
+ p.PushKeyTS,
+ p.Kind,
+ p.AppID,
+ p.AppDisplayName,
+ p.DeviceDisplayName,
+ p.ProfileTag,
+ p.Language,
+ string(data),
+ localpart)
+ })
+}
+
+// GetPushers returns the pushers matching the given localpart.
+func (d *Database) GetPushers(
+ ctx context.Context, localpart string,
+) ([]api.Pusher, error) {
+ return d.Pushers.SelectPushers(ctx, nil, localpart)
+}
+
+// RemovePusher deletes one pusher
+// Invoked when `append` is true and `kind` is null in
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
+func (d *Database) RemovePusher(
+ ctx context.Context, appid, pushkey, localpart string,
+) error {
+ return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
+ err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
+ if err == sql.ErrNoRows {
+ return nil
+ }
+ return err
+ })
+}
+
+// RemovePushers deletes all pushers that match given App Id and Push Key pair.
+// Invoked when `append` parameter is false in
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
+func (d *Database) RemovePushers(
+ ctx context.Context, appid, pushkey string,
+) error {
+ return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
+ return d.Pushers.DeletePushers(ctx, txn, appid, pushkey)
+ })
+}
diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go
new file mode 100644
index 00000000..fcfb1aad
--- /dev/null
+++ b/userapi/storage/sqlite3/notifications_table.go
@@ -0,0 +1,219 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ log "github.com/sirupsen/logrus"
+)
+
+type notificationsStatements struct {
+ insertStmt *sql.Stmt
+ deleteUpToStmt *sql.Stmt
+ updateReadStmt *sql.Stmt
+ selectStmt *sql.Stmt
+ selectCountStmt *sql.Stmt
+ selectRoomCountsStmt *sql.Stmt
+}
+
+const notificationSchema = `
+CREATE TABLE IF NOT EXISTS userapi_notifications (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ localpart TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ stream_pos BIGINT NOT NULL,
+ ts_ms BIGINT NOT NULL,
+ highlight BOOLEAN NOT NULL,
+ notification_json TEXT NOT NULL,
+ read BOOLEAN NOT NULL DEFAULT FALSE
+);
+
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
+CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
+`
+
+const insertNotificationSQL = "" +
+ "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
+
+const deleteNotificationsUpToSQL = "" +
+ "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
+
+const updateNotificationReadSQL = "" +
+ "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
+
+const selectNotificationSQL = "" +
+ "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
+ "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read ORDER BY localpart, id LIMIT $4"
+
+const selectNotificationCountSQL = "" +
+ "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
+ "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
+ ") AND NOT read"
+
+const selectRoomNotificationCountsSQL = "" +
+ "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
+ "WHERE localpart = $1 AND room_id = $2 AND NOT read"
+
+func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
+ s := &notificationsStatements{}
+ _, err := db.Exec(notificationSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertStmt, insertNotificationSQL},
+ {&s.deleteUpToStmt, deleteNotificationsUpToSQL},
+ {&s.updateReadStmt, updateNotificationReadSQL},
+ {&s.selectStmt, selectNotificationSQL},
+ {&s.selectCountStmt, selectNotificationCountSQL},
+ {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
+ }.Prepare(db)
+}
+
+// Insert inserts a notification into the database.
+func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
+ roomID, tsMS := n.RoomID, n.TS
+ nn := *n
+ // Clears out fields that have their own columns to (1) shrink the
+ // data and (2) avoid difficult-to-debug inconsistency bugs.
+ nn.RoomID = ""
+ nn.TS, nn.Read = 0, false
+ bs, err := json.Marshal(nn)
+ if err != nil {
+ return err
+ }
+ _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
+ return err
+}
+
+// DeleteUpTo deletes all previous notifications, up to and including the event.
+func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+// UpdateRead updates the "read" value for an event.
+func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
+ res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
+ if err != nil {
+ return false, err
+ }
+ nrows, err := res.RowsAffected()
+ if err != nil {
+ return true, err
+ }
+ log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
+ return nrows > 0, nil
+}
+
+func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
+
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ var maxID int64 = -1
+ var notifs []*api.Notification
+ for rows.Next() {
+ var id int64
+ var roomID string
+ var ts gomatrixserverlib.Timestamp
+ var read bool
+ var jsonStr string
+ err = rows.Scan(
+ &id,
+ &roomID,
+ &ts,
+ &read,
+ &jsonStr)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ var n api.Notification
+ err := json.Unmarshal([]byte(jsonStr), &n)
+ if err != nil {
+ return nil, 0, err
+ }
+ n.RoomID = roomID
+ n.TS = ts
+ n.Read = read
+ notifs = append(notifs, &n)
+
+ if maxID < id {
+ maxID = id
+ }
+ }
+ return notifs, maxID, rows.Err()
+}
+
+func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
+
+ if err != nil {
+ return 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var count int64
+ if err := rows.Scan(&count); err != nil {
+ return 0, err
+ }
+
+ return count, nil
+ }
+ return 0, rows.Err()
+}
+
+func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
+
+ if err != nil {
+ return 0, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
+
+ if rows.Next() {
+ var total, highlight int64
+ if err := rows.Scan(&total, &highlight); err != nil {
+ return 0, 0, err
+ }
+
+ return total, highlight, nil
+ }
+ return 0, 0, rows.Err()
+}
diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go
new file mode 100644
index 00000000..e718792e
--- /dev/null
+++ b/userapi/storage/sqlite3/pusher_table.go
@@ -0,0 +1,157 @@
+// Copyright 2021 Dan Peleg <dan@globekeeper.com>
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/sirupsen/logrus"
+)
+
+// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
+const pushersSchema = `
+CREATE TABLE IF NOT EXISTS userapi_pushers (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The Matrix user ID localpart for this pusher
+ localpart TEXT NOT NULL,
+ session_id BIGINT DEFAULT NULL,
+ profile_tag TEXT,
+ kind TEXT NOT NULL,
+ app_id TEXT NOT NULL,
+ app_display_name TEXT NOT NULL,
+ device_display_name TEXT NOT NULL,
+ pushkey TEXT NOT NULL,
+ pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
+ lang TEXT NOT NULL,
+ data TEXT NOT NULL
+);
+
+-- For faster deleting by app_id, pushkey pair.
+CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
+
+-- For faster retrieving by localpart.
+CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
+
+-- Pushkey must be unique for a given user and app.
+CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
+`
+
+const insertPusherSQL = "" +
+ "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
+ "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
+
+const selectPushersSQL = "" +
+ "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
+
+const deletePusherSQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
+
+const deletePushersByAppIdAndPushKeySQL = "" +
+ "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
+
+func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) {
+ s := &pushersStatements{}
+ _, err := db.Exec(pushersSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.insertPusherStmt, insertPusherSQL},
+ {&s.selectPushersStmt, selectPushersSQL},
+ {&s.deletePusherStmt, deletePusherSQL},
+ {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
+ }.Prepare(db)
+}
+
+type pushersStatements struct {
+ insertPusherStmt *sql.Stmt
+ selectPushersStmt *sql.Stmt
+ deletePusherStmt *sql.Stmt
+ deletePushersByAppIdAndPushKeyStmt *sql.Stmt
+}
+
+// insertPusher creates a new pusher.
+// Returns an error if the user already has a pusher with the given pusher pushkey.
+// Returns nil error success.
+func (s *pushersStatements) InsertPusher(
+ ctx context.Context, txn *sql.Tx, session_id int64,
+ pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
+) error {
+ _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
+ logrus.Debugf("Created pusher %d", session_id)
+ return err
+}
+
+func (s *pushersStatements) SelectPushers(
+ ctx context.Context, txn *sql.Tx, localpart string,
+) ([]api.Pusher, error) {
+ pushers := []api.Pusher{}
+ rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
+
+ if err != nil {
+ return pushers, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
+
+ for rows.Next() {
+ var pusher api.Pusher
+ var data []byte
+ err = rows.Scan(
+ &pusher.SessionID,
+ &pusher.PushKey,
+ &pusher.PushKeyTS,
+ &pusher.Kind,
+ &pusher.AppID,
+ &pusher.AppDisplayName,
+ &pusher.DeviceDisplayName,
+ &pusher.ProfileTag,
+ &pusher.Language,
+ &data)
+ if err != nil {
+ return pushers, err
+ }
+ err := json.Unmarshal(data, &pusher.Data)
+ if err != nil {
+ return pushers, err
+ }
+ pushers = append(pushers, pusher)
+ }
+
+ logrus.Debugf("Database returned %d pushers", len(pushers))
+ return pushers, rows.Err()
+}
+
+// deletePusher removes a single pusher by pushkey and user localpart.
+func (s *pushersStatements) DeletePusher(
+ ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
+) error {
+ _, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart)
+ return err
+}
+
+func (s *pushersStatements) DeletePushers(
+ ctx context.Context, txn *sql.Tx, appid, pushkey string,
+) error {
+ _, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey)
+ return err
+}
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index 98c24497..b5bb96c4 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
}
+ pusherTable, err := NewSQLitePusherTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
+ }
+ notificationsTable, err := NewSQLiteNotificationTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
+ }
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
@@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
+ Pushers: pusherTable,
+ Notifications: notificationsTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 12939ced..815e5119 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
)
type AccountDataTable interface {
@@ -93,3 +94,42 @@ type ThreePIDTable interface {
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
}
+
+type PusherTable interface {
+ InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
+ SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
+ DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
+ DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
+}
+
+type NotificationTable interface {
+ Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error
+ DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error)
+ UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error)
+ Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
+ SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
+ SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
+}
+
+type NotificationFilter uint32
+
+const (
+ // HighlightNotifications returns notifications that had a
+ // "highlight" tweak assigned to them from evaluating push rules.
+ HighlightNotifications NotificationFilter = 1 << iota
+
+ // NonHighlightNotifications returns notifications that don't
+ // match HighlightNotifications.
+ NonHighlightNotifications
+
+ // NoNotifications is a filter to exclude all types of
+ // notifications. It's useful as a zero value, but isn't likely to
+ // be used in a call to Notifications.Select*.
+ NoNotifications NotificationFilter = 0
+
+ // AllNotifications is a filter to include all types of
+ // notifications in Notifications.Select*. Note that PostgreSQL
+ // balks if this doesn't fit in INTEGER, even though we use
+ // uint32.
+ AllNotifications NotificationFilter = (1 << 31) - 1
+)