aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2022-10-21 10:48:25 +0200
committerGitHub <noreply@github.com>2022-10-21 10:48:25 +0200
commite57b30172227c4a0b7de15ba635b20921dedda5e (patch)
treefa73a898b6a58d2250e2e7c6fd7b7f5df8b0f7f2 /userapi/storage
parent40cfb9a4ea23f1c9214553255feb296c2578b213 (diff)
Set `display_name` and/or `avatar_url` for server notices (#2820)
This should fix #2815 by making sure we actually set the `display_name` and/or `avatar_url` and create the needed membership event. To avoid creating a new membership event when starting Dendrite, `SetAvatarURL` and `SetDisplayName` now return a `Changed` value, which also makes the regular endpoints idempotent.
Diffstat (limited to 'userapi/storage')
-rw-r--r--userapi/storage/interface.go4
-rw-r--r--userapi/storage/postgres/profile_table.go36
-rw-r--r--userapi/storage/shared/storage.go16
-rw-r--r--userapi/storage/sqlite3/profile_table.go40
-rw-r--r--userapi/storage/storage_test.go20
-rw-r--r--userapi/storage/tables/interface.go4
6 files changed, 88 insertions, 32 deletions
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index 02efe7af..fb12b53a 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -29,8 +29,8 @@ import (
type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
- SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
- SetDisplayName(ctx context.Context, localpart string, displayName string) error
+ SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
+ SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
}
type Account interface {
diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go
index f686127b..2753b23d 100644
--- a/userapi/storage/postgres/profile_table.go
+++ b/userapi/storage/postgres/profile_table.go
@@ -44,10 +44,18 @@ const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
- "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
+ "UPDATE userapi_profiles AS new" +
+ " SET avatar_url = $1" +
+ " FROM userapi_profiles AS old" +
+ " WHERE new.localpart = $2" +
+ " RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" +
- "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
+ "UPDATE userapi_profiles AS new" +
+ " SET display_name = $1" +
+ " FROM userapi_profiles AS old" +
+ " WHERE new.localpart = $2" +
+ " RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@@ -100,16 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
-) (err error) {
- _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
- return
+) (*authtypes.Profile, bool, error) {
+ profile := &authtypes.Profile{
+ Localpart: localpart,
+ AvatarURL: avatarURL,
+ }
+ var changed bool
+ stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
+ err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
+ return profile, changed, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
-) (err error) {
- _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
- return
+) (*authtypes.Profile, bool, error) {
+ profile := &authtypes.Profile{
+ Localpart: localpart,
+ DisplayName: displayName,
+ }
+ var changed bool
+ stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
+ err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
+ return profile, changed, err
}
func (s *profilesStatements) SelectProfilesBySearch(
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index 4e28f7b5..f8b6ad31 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart(
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
+) (profile *authtypes.Profile, changed bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
+ return err
})
+ return
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
+) (profile *authtypes.Profile, changed bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
+ return err
})
+ return
}
// SetPassword sets the account password to the given hash.
diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go
index 267daf04..b6130a1e 100644
--- a/userapi/storage/sqlite3/profile_table.go
+++ b/userapi/storage/sqlite3/profile_table.go
@@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
- "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
+ "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
+ " RETURNING display_name"
const setDisplayNameSQL = "" +
- "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
+ "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
+ " RETURNING avatar_url"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@@ -102,18 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
-) (err error) {
+) (*authtypes.Profile, bool, error) {
+ profile := &authtypes.Profile{
+ Localpart: localpart,
+ AvatarURL: avatarURL,
+ }
+ old, err := s.SelectProfileByLocalpart(ctx, localpart)
+ if err != nil {
+ return old, false, err
+ }
+ if old.AvatarURL == avatarURL {
+ return old, false, nil
+ }
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
- _, err = stmt.ExecContext(ctx, avatarURL, localpart)
- return
+ err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
+ return profile, true, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
-) (err error) {
+) (*authtypes.Profile, bool, error) {
+ profile := &authtypes.Profile{
+ Localpart: localpart,
+ DisplayName: displayName,
+ }
+ old, err := s.SelectProfileByLocalpart(ctx, localpart)
+ if err != nil {
+ return old, false, err
+ }
+ if old.DisplayName == displayName {
+ return old, false, nil
+ }
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
- _, err = stmt.ExecContext(ctx, displayName, localpart)
- return
+ err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
+ return profile, true, err
}
func (s *profilesStatements) SelectProfilesBySearch(
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index 8e5b32b6..354f085f 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -382,15 +382,23 @@ func Test_Profile(t *testing.T) {
// set avatar & displayname
wantProfile.DisplayName = "Alice"
- wantProfile.AvatarURL = "mxc://aliceAvatar"
- err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
+ gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
+ assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname")
- err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
+ assert.True(t, changed)
+
+ wantProfile.AvatarURL = "mxc://aliceAvatar"
+ gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
+ assert.NoError(t, err, "unable to set avatar url")
+ assert.Equal(t, wantProfile, gotProfile)
+ assert.True(t, changed)
+
+ // Setting the same avatar again doesn't change anything
+ wantProfile.AvatarURL = "mxc://aliceAvatar"
+ gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
- // verify profile
- gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
- assert.NoError(t, err, "unable to get profile by localpart")
assert.Equal(t, wantProfile, gotProfile)
+ assert.False(t, changed)
// search profiles
searchRes, err := db.SearchProfiles(ctx, "Alice", 2)
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index cc428799..1b239e44 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -84,8 +84,8 @@ type OpenIDTable interface {
type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
- SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
- SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
+ SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
+ SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}