diff options
author | Kegan Dougal <kegan@matrix.org> | 2021-07-28 18:30:04 +0100 |
---|---|---|
committer | Kegan Dougal <kegan@matrix.org> | 2021-07-28 18:30:04 +0100 |
commit | ed4097825bc65f2332bcdc975ed201841221ff7c (patch) | |
tree | 4ac6270fa282b1d13ac73c97e4078f1961dbdf4e /userapi | |
parent | 9e4618000e0347741eac1279bf6c94c3b9980785 (diff) |
Factor out StatementList to `sqlutil` and use it in `userapi`
It helps with the boilerplate.
Diffstat (limited to 'userapi')
14 files changed, 98 insertions, 208 deletions
diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go index 09eb2611..8ba890e7 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -61,16 +61,11 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return - } - if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { - return - } - if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertAccountDataStmt, insertAccountDataSQL}, + {&s.selectAccountDataStmt, selectAccountDataSQL}, + {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, + }.Prepare(db) } func (s *accountDataStatements) insertAccountData( diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 4eaa5b58..b57aa901 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -81,26 +81,15 @@ func (s *accountsStatements) execSchema(db *sql.DB) error { } func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { - return - } - if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { - return - } - if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil { - return - } - if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { - return - } - if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { - return - } - if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertAccountStmt, insertAccountSQL}, + {&s.updatePasswordStmt, updatePasswordSQL}, + {&s.deactivateAccountStmt, deactivateAccountSQL}, + {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, + {&s.selectPasswordHashStmt, selectPasswordHashSQL}, + {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + }.Prepare(db) } // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go index 0a2a2655..c1402d4d 100644 --- a/userapi/storage/accounts/postgres/key_backup_table.go +++ b/userapi/storage/accounts/postgres/key_backup_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" ) @@ -76,25 +77,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil { - return - } - if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil { - return - } - if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil { - return - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return - } - if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil { - return - } - if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertBackupKeyStmt, insertBackupKeySQL}, + {&s.updateBackupKeyStmt, updateBackupKeySQL}, + {&s.countKeysStmt, countKeysSQL}, + {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, + {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, + }.Prepare(db) } func (s keyBackupStatements) countKeys( diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go index 51a462b3..d73447b4 100644 --- a/userapi/storage/accounts/postgres/key_backup_version_table.go +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -20,6 +20,8 @@ import ( "encoding/json" "fmt" "strconv" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const keyBackupVersionTableSchema = ` @@ -72,25 +74,14 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { - return - } - if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { - return - } - if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { - return - } - if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { - return - } - if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { - return - } - if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertKeyBackupStmt, insertKeyBackupSQL}, + {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, + {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, + {&s.selectKeyBackupStmt, selectKeyBackupSQL}, + {&s.selectLatestVersionStmt, selectLatestVersionSQL}, + {&s.updateKeyBackupETagStmt, updateKeyBackupETagSQL}, + }.Prepare(db) } func (s *keyBackupVersionStatements) insertKeyBackup( diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/accounts/postgres/openid_table.go index 86c19705..190d141b 100644 --- a/userapi/storage/accounts/postgres/openid_table.go +++ b/userapi/storage/accounts/postgres/openid_table.go @@ -39,14 +39,11 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam if err != nil { return } - if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { - return - } - if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertTokenStmt, insertTokenSQL}, + {&s.selectTokenStmt, selectTokenSQL}, + }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/accounts/postgres/profile_table.go index 45d802f1..9313864b 100644 --- a/userapi/storage/accounts/postgres/profile_table.go +++ b/userapi/storage/accounts/postgres/profile_table.go @@ -64,22 +64,13 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { - return - } - if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { - return - } - if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { - return - } - if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { - return - } - if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertProfileStmt, insertProfileSQL}, + {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, + {&s.setAvatarURLStmt, setAvatarURLSQL}, + {&s.setDisplayNameStmt, setDisplayNameSQL}, + {&s.selectProfilesBySearchStmt, selectProfilesBySearchSQL}, + }.Prepare(db) } func (s *profilesStatements) insertProfile( diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/accounts/postgres/threepid_table.go index 7de96350..9280fc87 100644 --- a/userapi/storage/accounts/postgres/threepid_table.go +++ b/userapi/storage/accounts/postgres/threepid_table.go @@ -63,20 +63,12 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { - return - } - if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { - return - } - if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { - return - } - if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { - return - } - - return + return sqlutil.StatementList{ + {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, + {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, + {&s.insertThreePIDStmt, insertThreePIDSQL}, + {&s.deleteThreePIDStmt, deleteThreePIDSQL}, + }.Prepare(db) } func (s *threepidStatements) selectLocalpartForThreePID( diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index 870a3706..871f996e 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -62,16 +62,11 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { - return - } - if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { - return - } - if s.selectAccountDataByTypeStmt, err = db.Prepare(selectAccountDataByTypeSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertAccountDataStmt, insertAccountDataSQL}, + {&s.selectAccountDataStmt, selectAccountDataSQL}, + {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, + }.Prepare(db) } func (s *accountDataStatements) insertAccountData( diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 50f07237..8a7c8fba 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -81,26 +81,15 @@ func (s *accountsStatements) execSchema(db *sql.DB) error { func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { s.db = db - if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { - return - } - if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { - return - } - if s.deactivateAccountStmt, err = db.Prepare(deactivateAccountSQL); err != nil { - return - } - if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { - return - } - if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { - return - } - if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertAccountStmt, insertAccountSQL}, + {&s.updatePasswordStmt, updatePasswordSQL}, + {&s.deactivateAccountStmt, deactivateAccountSQL}, + {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, + {&s.selectPasswordHashStmt, selectPasswordHashSQL}, + {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + }.Prepare(db) } // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/accounts/sqlite3/key_backup_table.go index 67509351..837d38cf 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" ) @@ -76,25 +77,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertBackupKeyStmt, err = db.Prepare(insertBackupKeySQL); err != nil { - return - } - if s.updateBackupKeyStmt, err = db.Prepare(updateBackupKeySQL); err != nil { - return - } - if s.countKeysStmt, err = db.Prepare(countKeysSQL); err != nil { - return - } - if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { - return - } - if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil { - return - } - if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertBackupKeyStmt, insertBackupKeySQL}, + {&s.updateBackupKeyStmt, updateBackupKeySQL}, + {&s.countKeysStmt, countKeysSQL}, + {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, + {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, + }.Prepare(db) } func (s keyBackupStatements) countKeys( diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go index a9e7bf5d..4211ed0f 100644 --- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go +++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go @@ -20,6 +20,8 @@ import ( "encoding/json" "fmt" "strconv" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) const keyBackupVersionTableSchema = ` @@ -70,25 +72,14 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil { - return - } - if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil { - return - } - if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil { - return - } - if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil { - return - } - if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil { - return - } - if s.updateKeyBackupETagStmt, err = db.Prepare(updateKeyBackupETagSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertKeyBackupStmt, insertKeyBackupSQL}, + {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, + {&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, + {&s.selectKeyBackupStmt, selectKeyBackupSQL}, + {&s.selectLatestVersionStmt, selectLatestVersionSQL}, + {&s.updateKeyBackupETagStmt, updateKeyBackupETagSQL}, + }.Prepare(db) } func (s *keyBackupVersionStatements) insertKeyBackup( diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/accounts/sqlite3/openid_table.go index 80b9dd4c..98c0488b 100644 --- a/userapi/storage/accounts/sqlite3/openid_table.go +++ b/userapi/storage/accounts/sqlite3/openid_table.go @@ -41,14 +41,11 @@ func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerNam if err != nil { return err } - if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil { - return - } - if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil { - return - } s.serverName = server - return + return sqlutil.StatementList{ + {&s.insertTokenStmt, insertTokenSQL}, + {&s.selectTokenStmt, selectTokenSQL}, + }.Prepare(db) } // insertToken inserts a new OpenID Connect token to the DB. diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index a67e892f..a92e9566 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -66,22 +66,13 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.insertProfileStmt, err = db.Prepare(insertProfileSQL); err != nil { - return - } - if s.selectProfileByLocalpartStmt, err = db.Prepare(selectProfileByLocalpartSQL); err != nil { - return - } - if s.setAvatarURLStmt, err = db.Prepare(setAvatarURLSQL); err != nil { - return - } - if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil { - return - } - if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil { - return - } - return + return sqlutil.StatementList{ + {&s.insertProfileStmt, insertProfileSQL}, + {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, + {&s.setAvatarURLStmt, setAvatarURLSQL}, + {&s.setDisplayNameStmt, setDisplayNameSQL}, + {&s.selectProfilesBySearchStmt, selectProfilesBySearchSQL}, + }.Prepare(db) } func (s *profilesStatements) insertProfile( diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 43112d38..9dc0e2d2 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -66,20 +66,12 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { if err != nil { return } - if s.selectLocalpartForThreePIDStmt, err = db.Prepare(selectLocalpartForThreePIDSQL); err != nil { - return - } - if s.selectThreePIDsForLocalpartStmt, err = db.Prepare(selectThreePIDsForLocalpartSQL); err != nil { - return - } - if s.insertThreePIDStmt, err = db.Prepare(insertThreePIDSQL); err != nil { - return - } - if s.deleteThreePIDStmt, err = db.Prepare(deleteThreePIDSQL); err != nil { - return - } - - return + return sqlutil.StatementList{ + {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, + {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, + {&s.insertThreePIDStmt, insertThreePIDSQL}, + {&s.deleteThreePIDStmt, deleteThreePIDSQL}, + }.Prepare(db) } func (s *threepidStatements) selectLocalpartForThreePID( |