diff options
Diffstat (limited to 'userapi/storage')
-rw-r--r-- | userapi/storage/interface.go | 13 | ||||
-rw-r--r-- | userapi/storage/postgres/notifications_table.go | 219 | ||||
-rw-r--r-- | userapi/storage/postgres/pusher_table.go | 157 | ||||
-rw-r--r-- | userapi/storage/postgres/storage.go | 10 | ||||
-rw-r--r-- | userapi/storage/shared/storage.go | 109 | ||||
-rw-r--r-- | userapi/storage/sqlite3/notifications_table.go | 219 | ||||
-rw-r--r-- | userapi/storage/sqlite3/pusher_table.go | 157 | ||||
-rw-r--r-- | userapi/storage/sqlite3/storage.go | 10 | ||||
-rw-r--r-- | userapi/storage/tables/interface.go | 40 |
9 files changed, 925 insertions, 9 deletions
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index a131dac4..6d22fea9 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" ) type Database interface { @@ -89,6 +90,18 @@ type Database interface { // GetLoginTokenDataByToken returns the data associated with the given token. // May return sql.ErrNoRows. GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) + + InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + + UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error + GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string) error + RemovePushers(ctx context.Context, appid, pushkey string) error } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go new file mode 100644 index 00000000..7bcc0f9c --- /dev/null +++ b/userapi/storage/postgres/notifications_table.go @@ -0,0 +1,219 @@ +// Copyright 2021 Dan Peleg <dan@globekeeper.com> +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id BIGSERIAL PRIMARY KEY, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + }.Prepare(db) +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go new file mode 100644 index 00000000..670dc916 --- /dev/null +++ b/userapi/storage/postgres/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg <dan@globekeeper.com> +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id BIGSERIAL PRIMARY KEY, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index ac5c59b8..c74a999f 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) } + pusherTable, err := NewPostgresPusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewPostgresNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -95,6 +103,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewDummyWriter(), diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 5f1f9500..a58974b4 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -29,6 +29,7 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" @@ -47,6 +48,8 @@ type Database struct { KeyBackupVersions tables.KeyBackupVersionTable Devices tables.DevicesTable LoginTokens tables.LoginTokenTable + Notifications tables.NotificationTable + Pushers tables.PusherTable LoginTokenLifetime time.Duration ServerName gomatrixserverlib.ServerName BcryptCost int @@ -160,15 +163,12 @@ func (d *Database) createAccount( if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { return nil, err } - if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ - "global": { - "content": [], - "override": [], - "room": [], - "sender": [], - "underride": [] - } - }`)); err != nil { + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + prbs, err := json.Marshal(pushRuleSets) + if err != nil { + return nil, err + } + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { return nil, err } return account, nil @@ -670,3 +670,94 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { return d.LoginTokens.SelectLoginToken(ctx, token) } + +func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + }) +} + +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) + return err + }) + return +} + +func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) + return err + }) + return +} + +func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter) +} + +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { + return d.Notifications.SelectCount(ctx, nil, localpart, filter) +} + +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { + return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) +} + +func (d *Database) UpsertPusher( + ctx context.Context, p api.Pusher, localpart string, +) error { + data, err := json.Marshal(p.Data) + if err != nil { + return err + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.Pushers.InsertPusher( + ctx, txn, + p.SessionID, + p.PushKey, + p.PushKeyTS, + p.Kind, + p.AppID, + p.AppDisplayName, + p.DeviceDisplayName, + p.ProfileTag, + p.Language, + string(data), + localpart) + }) +} + +// GetPushers returns the pushers matching the given localpart. +func (d *Database) GetPushers( + ctx context.Context, localpart string, +) ([]api.Pusher, error) { + return d.Pushers.SelectPushers(ctx, nil, localpart) +} + +// RemovePusher deletes one pusher +// Invoked when `append` is true and `kind` is null in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePusher( + ctx context.Context, appid, pushkey, localpart string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) + if err == sql.ErrNoRows { + return nil + } + return err + }) +} + +// RemovePushers deletes all pushers that match given App Id and Push Key pair. +// Invoked when `append` parameter is false in +// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set +func (d *Database) RemovePushers( + ctx context.Context, appid, pushkey string, +) error { + return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { + return d.Pushers.DeletePushers(ctx, txn, appid, pushkey) + }) +} diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go new file mode 100644 index 00000000..fcfb1aad --- /dev/null +++ b/userapi/storage/sqlite3/notifications_table.go @@ -0,0 +1,219 @@ +// Copyright 2021 Dan Peleg <dan@globekeeper.com> +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + log "github.com/sirupsen/logrus" +) + +type notificationsStatements struct { + insertStmt *sql.Stmt + deleteUpToStmt *sql.Stmt + updateReadStmt *sql.Stmt + selectStmt *sql.Stmt + selectCountStmt *sql.Stmt + selectRoomCountsStmt *sql.Stmt +} + +const notificationSchema = ` +CREATE TABLE IF NOT EXISTS userapi_notifications ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + localpart TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + stream_pos BIGINT NOT NULL, + ts_ms BIGINT NOT NULL, + highlight BOOLEAN NOT NULL, + notification_json TEXT NOT NULL, + read BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +` + +const insertNotificationSQL = "" + + "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + +const deleteNotificationsUpToSQL = "" + + "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + +const updateNotificationReadSQL = "" + + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + +const selectNotificationSQL = "" + + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $4" + +const selectNotificationCountSQL = "" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + + "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read" + +const selectRoomNotificationCountsSQL = "" + + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + + "WHERE localpart = $1 AND room_id = $2 AND NOT read" + +func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) { + s := ¬ificationsStatements{} + _, err := db.Exec(notificationSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertStmt, insertNotificationSQL}, + {&s.deleteUpToStmt, deleteNotificationsUpToSQL}, + {&s.updateReadStmt, updateNotificationReadSQL}, + {&s.selectStmt, selectNotificationSQL}, + {&s.selectCountStmt, selectNotificationCountSQL}, + {&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL}, + }.Prepare(db) +} + +// Insert inserts a notification into the database. +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error { + roomID, tsMS := n.RoomID, n.TS + nn := *n + // Clears out fields that have their own columns to (1) shrink the + // data and (2) avoid difficult-to-debug inconsistency bugs. + nn.RoomID = "" + nn.TS, nn.Read = 0, false + bs, err := json.Marshal(nn) + if err != nil { + return err + } + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + return err +} + +// DeleteUpTo deletes all previous notifications, up to and including the event. +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows) + return nrows > 0, nil +} + +// UpdateRead updates the "read" value for an event. +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) + if err != nil { + return false, err + } + nrows, err := res.RowsAffected() + if err != nil { + return true, err + } + log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows) + return nrows > 0, nil +} + +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) + + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + var maxID int64 = -1 + var notifs []*api.Notification + for rows.Next() { + var id int64 + var roomID string + var ts gomatrixserverlib.Timestamp + var read bool + var jsonStr string + err = rows.Scan( + &id, + &roomID, + &ts, + &read, + &jsonStr) + if err != nil { + return nil, 0, err + } + + var n api.Notification + err := json.Unmarshal([]byte(jsonStr), &n) + if err != nil { + return nil, 0, err + } + n.RoomID = roomID + n.TS = ts + n.Read = read + notifs = append(notifs, &n) + + if maxID < id { + maxID = id + } + } + return notifs, maxID, rows.Err() +} + +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter)) + + if err != nil { + return 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var count int64 + if err := rows.Scan(&count); err != nil { + return 0, err + } + + return count, nil + } + return 0, rows.Err() +} + +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) { + rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID) + + if err != nil { + return 0, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed") + + if rows.Next() { + var total, highlight int64 + if err := rows.Scan(&total, &highlight); err != nil { + return 0, 0, err + } + + return total, highlight, nil + } + return 0, 0, rows.Err() +} diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go new file mode 100644 index 00000000..e718792e --- /dev/null +++ b/userapi/storage/sqlite3/pusher_table.go @@ -0,0 +1,157 @@ +// Copyright 2021 Dan Peleg <dan@globekeeper.com> +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers +const pushersSchema = ` +CREATE TABLE IF NOT EXISTS userapi_pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The Matrix user ID localpart for this pusher + localpart TEXT NOT NULL, + session_id BIGINT DEFAULT NULL, + profile_tag TEXT, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + pushkey_ts_ms BIGINT NOT NULL DEFAULT 0, + lang TEXT NOT NULL, + data TEXT NOT NULL +); + +-- For faster deleting by app_id, pushkey pair. +CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); + +-- For faster retrieving by localpart. +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); + +-- Pushkey must be unique for a given user and app. +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +` + +const insertPusherSQL = "" + + "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + + "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + +const selectPushersSQL = "" + + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + +const deletePusherSQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + +const deletePushersByAppIdAndPushKeySQL = "" + + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" + +func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) { + s := &pushersStatements{} + _, err := db.Exec(pushersSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.insertPusherStmt, insertPusherSQL}, + {&s.selectPushersStmt, selectPushersSQL}, + {&s.deletePusherStmt, deletePusherSQL}, + {&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL}, + }.Prepare(db) +} + +type pushersStatements struct { + insertPusherStmt *sql.Stmt + selectPushersStmt *sql.Stmt + deletePusherStmt *sql.Stmt + deletePushersByAppIdAndPushKeyStmt *sql.Stmt +} + +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) InsertPusher( + ctx context.Context, txn *sql.Tx, session_id int64, + pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, +) error { + _, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) + logrus.Debugf("Created pusher %d", session_id) + return err +} + +func (s *pushersStatements) SelectPushers( + ctx context.Context, txn *sql.Tx, localpart string, +) ([]api.Pusher, error) { + pushers := []api.Pusher{} + rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) + + if err != nil { + return pushers, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed") + + for rows.Next() { + var pusher api.Pusher + var data []byte + err = rows.Scan( + &pusher.SessionID, + &pusher.PushKey, + &pusher.PushKeyTS, + &pusher.Kind, + &pusher.AppID, + &pusher.AppDisplayName, + &pusher.DeviceDisplayName, + &pusher.ProfileTag, + &pusher.Language, + &data) + if err != nil { + return pushers, err + } + err := json.Unmarshal(data, &pusher.Data) + if err != nil { + return pushers, err + } + pushers = append(pushers, pusher) + } + + logrus.Debugf("Database returned %d pushers", len(pushers)) + return pushers, rows.Err() +} + +// deletePusher removes a single pusher by pushkey and user localpart. +func (s *pushersStatements) DeletePusher( + ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, +) error { + _, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart) + return err +} + +func (s *pushersStatements) DeletePushers( + ctx context.Context, txn *sql.Tx, appid, pushkey string, +) error { + _, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey) + return err +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 98c24497..b5bb96c4 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) } + pusherTable, err := NewSQLitePusherTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresPusherTable: %w", err) + } + notificationsTable, err := NewSQLiteNotificationTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err) + } return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, @@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver OpenIDTokens: openIDTable, Profiles: profilesTable, ThreePIDs: threePIDTable, + Pushers: pusherTable, + Notifications: notificationsTable, ServerName: serverName, DB: db, Writer: sqlutil.NewExclusiveWriter(), diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 12939ced..815e5119 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) type AccountDataTable interface { @@ -93,3 +94,42 @@ type ThreePIDTable interface { InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } + +type PusherTable interface { + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error + DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error +} + +type NotificationTable interface { + Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) +} + +type NotificationFilter uint32 + +const ( + // HighlightNotifications returns notifications that had a + // "highlight" tweak assigned to them from evaluating push rules. + HighlightNotifications NotificationFilter = 1 << iota + + // NonHighlightNotifications returns notifications that don't + // match HighlightNotifications. + NonHighlightNotifications + + // NoNotifications is a filter to exclude all types of + // notifications. It's useful as a zero value, but isn't likely to + // be used in a call to Notifications.Select*. + NoNotifications NotificationFilter = 0 + + // AllNotifications is a filter to include all types of + // notifications in Notifications.Select*. Note that PostgreSQL + // balks if this doesn't fit in INTEGER, even though we use + // uint32. + AllNotifications NotificationFilter = (1 << 31) - 1 +) |