aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2022-10-21 10:48:25 +0200
committerGitHub <noreply@github.com>2022-10-21 10:48:25 +0200
commite57b30172227c4a0b7de15ba635b20921dedda5e (patch)
treefa73a898b6a58d2250e2e7c6fd7b7f5df8b0f7f2
parent40cfb9a4ea23f1c9214553255feb296c2578b213 (diff)
Set `display_name` and/or `avatar_url` for server notices (#2820)
This should fix #2815 by making sure we actually set the `display_name` and/or `avatar_url` and create the needed membership event. To avoid creating a new membership event when starting Dendrite, `SetAvatarURL` and `SetDisplayName` now return a `Changed` value, which also makes the regular endpoints idempotent.
-rw-r--r--clientapi/routing/profile.go127
-rw-r--r--clientapi/routing/routing.go2
-rw-r--r--clientapi/routing/server_notices.go31
-rw-r--r--userapi/api/api.go12
-rw-r--r--userapi/api/api_trace.go2
-rw-r--r--userapi/internal/api.go14
-rw-r--r--userapi/inthttp/client.go2
-rw-r--r--userapi/storage/interface.go4
-rw-r--r--userapi/storage/postgres/profile_table.go36
-rw-r--r--userapi/storage/shared/storage.go16
-rw-r--r--userapi/storage/sqlite3/profile_table.go40
-rw-r--r--userapi/storage/storage_test.go20
-rw-r--r--userapi/storage/tables/interface.go4
-rw-r--r--userapi/userapi_test.go9
14 files changed, 189 insertions, 130 deletions
diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go
index 0685c735..c9647eb1 100644
--- a/clientapi/routing/profile.go
+++ b/clientapi/routing/profile.go
@@ -19,6 +19,8 @@ import (
"net/http"
"time"
+ "github.com/matrix-org/gomatrixserverlib"
+
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
@@ -27,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/util"
@@ -126,20 +127,6 @@ func SetAvatarURL(
}
}
- res := &userapi.QueryProfileResponse{}
- err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
- UserID: userID,
- }, res)
- if err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
- return jsonerror.InternalServerError()
- }
- oldProfile := &authtypes.Profile{
- Localpart: localpart,
- DisplayName: res.DisplayName,
- AvatarURL: res.AvatarURL,
- }
-
setRes := &userapi.PerformSetAvatarURLResponse{}
if err = profileAPI.SetAvatarURL(req.Context(), &userapi.PerformSetAvatarURLRequest{
Localpart: localpart,
@@ -148,41 +135,17 @@ func SetAvatarURL(
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetAvatarURL failed")
return jsonerror.InternalServerError()
}
-
- var roomsRes api.QueryRoomsForUserResponse
- err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
- UserID: device.UserID,
- WantMembership: "join",
- }, &roomsRes)
- if err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
- return jsonerror.InternalServerError()
- }
-
- newProfile := authtypes.Profile{
- Localpart: localpart,
- DisplayName: oldProfile.DisplayName,
- AvatarURL: r.AvatarURL,
- }
-
- events, err := buildMembershipEvents(
- req.Context(), roomsRes.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
- )
- switch e := err.(type) {
- case nil:
- case gomatrixserverlib.BadJSONError:
+ // No need to build new membership events, since nothing changed
+ if !setRes.Changed {
return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: jsonerror.BadJSON(e.Error()),
+ Code: http.StatusOK,
+ JSON: struct{}{},
}
- default:
- util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
- return jsonerror.InternalServerError()
}
- if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
- return jsonerror.InternalServerError()
+ response, err := updateProfile(req.Context(), rsAPI, device, setRes.Profile, userID, cfg, evTime)
+ if err != nil {
+ return response
}
return util.JSONResponse{
@@ -255,47 +218,51 @@ func SetDisplayName(
}
}
- pRes := &userapi.QueryProfileResponse{}
- err = profileAPI.QueryProfile(req.Context(), &userapi.QueryProfileRequest{
- UserID: userID,
- }, pRes)
- if err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("profileAPI.QueryProfile failed")
- return jsonerror.InternalServerError()
- }
- oldProfile := &authtypes.Profile{
- Localpart: localpart,
- DisplayName: pRes.DisplayName,
- AvatarURL: pRes.AvatarURL,
- }
-
+ profileRes := &userapi.PerformUpdateDisplayNameResponse{}
err = profileAPI.SetDisplayName(req.Context(), &userapi.PerformUpdateDisplayNameRequest{
Localpart: localpart,
DisplayName: r.DisplayName,
- }, &struct{}{})
+ }, profileRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("profileAPI.SetDisplayName failed")
return jsonerror.InternalServerError()
}
+ // No need to build new membership events, since nothing changed
+ if !profileRes.Changed {
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: struct{}{},
+ }
+ }
+
+ response, err := updateProfile(req.Context(), rsAPI, device, profileRes.Profile, userID, cfg, evTime)
+ if err != nil {
+ return response
+ }
+
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: struct{}{},
+ }
+}
+func updateProfile(
+ ctx context.Context, rsAPI api.ClientRoomserverAPI, device *userapi.Device,
+ profile *authtypes.Profile,
+ userID string, cfg *config.ClientAPI, evTime time.Time,
+) (util.JSONResponse, error) {
var res api.QueryRoomsForUserResponse
- err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
+ err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &res)
if err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
- return jsonerror.InternalServerError()
- }
-
- newProfile := authtypes.Profile{
- Localpart: localpart,
- DisplayName: r.DisplayName,
- AvatarURL: oldProfile.AvatarURL,
+ util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed")
+ return jsonerror.InternalServerError(), err
}
events, err := buildMembershipEvents(
- req.Context(), res.RoomIDs, newProfile, userID, cfg, evTime, rsAPI,
+ ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI,
)
switch e := err.(type) {
case nil:
@@ -303,21 +270,17 @@ func SetDisplayName(
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(e.Error()),
- }
+ }, e
default:
- util.GetLogger(req.Context()).WithError(err).Error("buildMembershipEvents failed")
- return jsonerror.InternalServerError()
- }
-
- if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
- return jsonerror.InternalServerError()
+ util.GetLogger(ctx).WithError(err).Error("buildMembershipEvents failed")
+ return jsonerror.InternalServerError(), e
}
- return util.JSONResponse{
- Code: http.StatusOK,
- JSON: struct{}{},
+ if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, true); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
+ return jsonerror.InternalServerError(), err
}
+ return util.JSONResponse{}, nil
}
// getProfile gets the full profile of a user by querying the database or a
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index ec5ca899..4ca8e59c 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -178,7 +178,7 @@ func Setup(
// server notifications
if cfg.Matrix.ServerNotices.Enabled {
logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice")
- serverNotificationSender, err := getSenderDevice(context.Background(), userAPI, cfg)
+ serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg)
if err != nil {
logrus.WithError(err).Fatal("unable to get account for sending sending server notices")
}
diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go
index 7729eddd..a6a78061 100644
--- a/clientapi/routing/server_notices.go
+++ b/clientapi/routing/server_notices.go
@@ -277,6 +277,7 @@ func (r sendServerNoticeRequest) valid() (ok bool) {
// It returns an userapi.Device, which is used for building the event
func getSenderDevice(
ctx context.Context,
+ rsAPI api.ClientRoomserverAPI,
userAPI userapi.ClientUserAPI,
cfg *config.ClientAPI,
) (*userapi.Device, error) {
@@ -291,16 +292,32 @@ func getSenderDevice(
return nil, err
}
- // set the avatarurl for the user
- res := &userapi.PerformSetAvatarURLResponse{}
+ // Set the avatarurl for the user
+ avatarRes := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
- }, res); err != nil {
+ }, avatarRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
return nil, err
}
+ profile := avatarRes.Profile
+
+ // Set the displayname for the user
+ displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
+ if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
+ Localpart: cfg.Matrix.ServerNotices.LocalPart,
+ DisplayName: cfg.Matrix.ServerNotices.DisplayName,
+ }, displayNameRes); err != nil {
+ util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
+ return nil, err
+ }
+
+ if displayNameRes.Changed {
+ profile.DisplayName = cfg.Matrix.ServerNotices.DisplayName
+ }
+
// Check if we got existing devices
deviceRes := &userapi.QueryDevicesResponse{}
err = userAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{
@@ -310,7 +327,15 @@ func getSenderDevice(
return nil, err
}
+ // We've got an existing account, return the first device of it
if len(deviceRes.Devices) > 0 {
+ // If there were changes to the profile, create a new membership event
+ if displayNameRes.Changed || avatarRes.Changed {
+ _, err = updateProfile(ctx, rsAPI, &deviceRes.Devices[0], profile, accRes.Account.UserID, cfg, time.Now())
+ if err != nil {
+ return nil, err
+ }
+ }
return &deviceRes.Devices[0], nil
}
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 66ee9c7c..eef29144 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -96,7 +96,7 @@ type ClientUserAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
- SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error
+ SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
@@ -579,7 +579,10 @@ type Notification struct {
type PerformSetAvatarURLRequest struct {
Localpart, AvatarURL string
}
-type PerformSetAvatarURLResponse struct{}
+type PerformSetAvatarURLResponse struct {
+ Profile *authtypes.Profile `json:"profile"`
+ Changed bool `json:"changed"`
+}
type QueryNumericLocalpartResponse struct {
ID int64
@@ -606,6 +609,11 @@ type PerformUpdateDisplayNameRequest struct {
Localpart, DisplayName string
}
+type PerformUpdateDisplayNameResponse struct {
+ Profile *authtypes.Profile `json:"profile"`
+ Changed bool `json:"changed"`
+}
+
type QueryLocalpartForThreePIDRequest struct {
ThreePID, Medium string
}
diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go
index 7e2f6961..90834f7e 100644
--- a/userapi/api/api_trace.go
+++ b/userapi/api/api_trace.go
@@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req
return err
}
-func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error {
+func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error {
err := t.Impl.SetDisplayName(ctx, req, res)
util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
return err
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index 2f7795df..63044eed 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
- if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
+ if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@@ -813,7 +813,10 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
}
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
- return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
+ profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
+ res.Profile = profile
+ res.Changed = changed
+ return err
}
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
@@ -847,8 +850,11 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
}
}
-func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error {
- return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
+func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
+ profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
+ res.Profile = profile
+ res.Changed = changed
+ return err
}
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index a375d6ca..aa5d46d9 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword(
func (h *httpUserInternalAPI) SetDisplayName(
ctx context.Context,
request *api.PerformUpdateDisplayNameRequest,
- response *struct{},
+ response *api.PerformUpdateDisplayNameResponse,
) error {
return httputil.CallInternalRPCAPI(
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index 02efe7af..fb12b53a 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -29,8 +29,8 @@ import (
type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
- SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
- SetDisplayName(ctx context.Context, localpart string, displayName string) error
+ SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
+ SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
}
type Account interface {
diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go
index f686127b..2753b23d 100644
--- a/userapi/storage/postgres/profile_table.go
+++ b/userapi/storage/postgres/profile_table.go
@@ -44,10 +44,18 @@ const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
- "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
+ "UPDATE userapi_profiles AS new" +
+ " SET avatar_url = $1" +
+ " FROM userapi_profiles AS old" +
+ " WHERE new.localpart = $2" +
+ " RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" +
- "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
+ "UPDATE userapi_profiles AS new" +
+ " SET display_name = $1" +
+ " FROM userapi_profiles AS old" +
+ " WHERE new.localpart = $2" +
+ " RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@@ -100,16 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
-) (err error) {
- _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
- return
+) (*authtypes.Profile, bool, error) {
+ profile := &authtypes.Profile{
+ Localpart: localpart,
+ AvatarURL: avatarURL,
+ }
+ var changed bool
+ stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
+ err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
+ return profile, changed, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
-) (err error) {
- _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
- return
+) (*authtypes.Profile, bool, error) {
+ profile := &authtypes.Profile{
+ Localpart: localpart,
+ DisplayName: displayName,
+ }
+ var changed bool
+ stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
+ err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
+ return profile, changed, err
}
func (s *profilesStatements) SelectProfilesBySearch(
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index 4e28f7b5..f8b6ad31 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart(
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
+) (profile *authtypes.Profile, changed bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
+ return err
})
+ return
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
-) error {
- return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
+) (profile *authtypes.Profile, changed bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
+ return err
})
+ return
}
// SetPassword sets the account password to the given hash.
diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go
index 267daf04..b6130a1e 100644
--- a/userapi/storage/sqlite3/profile_table.go
+++ b/userapi/storage/sqlite3/profile_table.go
@@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
- "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
+ "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
+ " RETURNING display_name"
const setDisplayNameSQL = "" +
- "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
+ "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
+ " RETURNING avatar_url"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@@ -102,18 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
-) (err error) {
+) (*authtypes.Profile, bool, error) {
+ profile := &authtypes.Profile{
+ Localpart: localpart,
+ AvatarURL: avatarURL,
+ }
+ old, err := s.SelectProfileByLocalpart(ctx, localpart)
+ if err != nil {
+ return old, false, err
+ }
+ if old.AvatarURL == avatarURL {
+ return old, false, nil
+ }
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
- _, err = stmt.ExecContext(ctx, avatarURL, localpart)
- return
+ err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
+ return profile, true, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
-) (err error) {
+) (*authtypes.Profile, bool, error) {
+ profile := &authtypes.Profile{
+ Localpart: localpart,
+ DisplayName: displayName,
+ }
+ old, err := s.SelectProfileByLocalpart(ctx, localpart)
+ if err != nil {
+ return old, false, err
+ }
+ if old.DisplayName == displayName {
+ return old, false, nil
+ }
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
- _, err = stmt.ExecContext(ctx, displayName, localpart)
- return
+ err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
+ return profile, true, err
}
func (s *profilesStatements) SelectProfilesBySearch(
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index 8e5b32b6..354f085f 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -382,15 +382,23 @@ func Test_Profile(t *testing.T) {
// set avatar & displayname
wantProfile.DisplayName = "Alice"
- wantProfile.AvatarURL = "mxc://aliceAvatar"
- err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
+ gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
+ assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname")
- err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
+ assert.True(t, changed)
+
+ wantProfile.AvatarURL = "mxc://aliceAvatar"
+ gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
+ assert.NoError(t, err, "unable to set avatar url")
+ assert.Equal(t, wantProfile, gotProfile)
+ assert.True(t, changed)
+
+ // Setting the same avatar again doesn't change anything
+ wantProfile.AvatarURL = "mxc://aliceAvatar"
+ gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
- // verify profile
- gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
- assert.NoError(t, err, "unable to get profile by localpart")
assert.Equal(t, wantProfile, gotProfile)
+ assert.False(t, changed)
// search profiles
searchRes, err := db.SearchProfiles(ctx, "Alice", 2)
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index cc428799..1b239e44 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -84,8 +84,8 @@ type OpenIDTable interface {
type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
- SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
- SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
+ SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
+ SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 4417f4dc..aaa93f45 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -23,13 +23,14 @@ import (
"time"
"github.com/gorilla/mux"
+ "github.com/matrix-org/gomatrixserverlib"
+ "golang.org/x/crypto/bcrypt"
+
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp"
- "github.com/matrix-org/gomatrixserverlib"
- "golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
@@ -83,10 +84,10 @@ func TestQueryProfile(t *testing.T) {
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
- if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
+ if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
t.Fatalf("failed to set avatar url: %s", err)
}
- if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
+ if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
t.Fatalf("failed to set display name: %s", err)
}