diff options
Diffstat (limited to 'userapi/storage/accounts/postgres')
-rw-r--r-- | userapi/storage/accounts/postgres/profile_table.go | 31 | ||||
-rw-r--r-- | userapi/storage/accounts/postgres/storage.go | 7 |
2 files changed, 38 insertions, 0 deletions
diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go index d2cbeb8e..14b12c35 100644 --- a/userapi/storage/accounts/postgres/profile_table.go +++ b/userapi/storage/accounts/postgres/profile_table.go @@ -17,8 +17,10 @@ package postgres import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/internal" ) const profilesSchema = ` @@ -45,11 +47,15 @@ const setAvatarURLSQL = "" + const setDisplayNameSQL = "" + "UPDATE account_profiles SET display_name = $1 WHERE localpart = $2" +const selectProfilesBySearchSQL = "" + + "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + type profilesStatements struct { insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt setDisplayNameStmt *sql.Stmt + selectProfilesBySearchStmt *sql.Stmt } func (s *profilesStatements) prepare(db *sql.DB) (err error) { @@ -69,6 +75,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { return } + if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { + return + } return } @@ -105,3 +114,25 @@ func (s *profilesStatements) setDisplayName( _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) return } + +func (s *profilesStatements) selectProfilesBySearch( + ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + var profiles []authtypes.Profile + // The fmt.Sprintf directive below is building a parameter for the + // "LIKE" condition in the SQL query. %% escapes the % char, so the + // statement in the end will look like "LIKE %searchString%". + rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") + for rows.Next() { + var profile authtypes.Profile + if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + return nil, err + } + profiles = append(profiles, profile) + } + return profiles, nil +} diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index c76b92f1..f56fb6d8 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -298,3 +298,10 @@ func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, ) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } + +// SearchProfiles returns all profiles where the provided localpart or display name +// match any part of the profiles in the database. +func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int, +) ([]authtypes.Profile, error) { + return d.profiles.selectProfilesBySearch(ctx, searchString, limit) +} |