diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2020-09-01 11:28:35 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-01 11:28:35 +0100 |
commit | 0ab5bccd11eea6063968c60fbdf5b36ade22da81 (patch) | |
tree | 0371a076b5ea9acd6693e27587c77c9bbf287f1b /userapi | |
parent | 3f9b829bc570d5f6353eda21ecf3d0088e4d9c50 (diff) |
Storage tweaks (#1373)
* Sync API tweaks
* User API tweaks
Diffstat (limited to 'userapi')
-rw-r--r-- | userapi/storage/accounts/sqlite3/account_data_table.go | 14 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/accounts_table.go | 21 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/profile_table.go | 22 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/storage.go | 28 | ||||
-rw-r--r-- | userapi/storage/accounts/sqlite3/threepid_table.go | 22 |
5 files changed, 47 insertions, 60 deletions
diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index aee8db6e..f9430c24 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -18,8 +18,6 @@ import ( "context" "database/sql" "encoding/json" - - "github.com/matrix-org/dendrite/internal/sqlutil" ) const accountDataSchema = ` @@ -51,15 +49,13 @@ const selectAccountDataByTypeSQL = "" + type accountDataStatements struct { db *sql.DB - writer sqlutil.Writer insertAccountDataStmt *sql.Stmt selectAccountDataStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { +func (s *accountDataStatements) prepare(db *sql.DB) (err error) { s.db = db - s.writer = writer _, err = db.Exec(accountDataSchema) if err != nil { return @@ -78,11 +74,9 @@ func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err func (s *accountDataStatements) insertAccountData( ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, -) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) - return err - }) +) error { + _, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + return err } func (s *accountDataStatements) selectAccountData( diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 83b90668..798a6de9 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -20,7 +20,6 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/userutil" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -59,7 +58,6 @@ const selectNewNumericLocalpartSQL = "" + type accountsStatements struct { db *sql.DB - writer sqlutil.Writer insertAccountStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt @@ -67,9 +65,9 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { +func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { s.db = db - s.writer = writer + _, err = db.Exec(accountsSchema) if err != nil { return @@ -99,15 +97,12 @@ func (s *accountsStatements) insertAccount( createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - var err error - if appserviceID == "" { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) - } else { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) - } - return err - }) + var err error + if appserviceID == "" { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + } else { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + } if err != nil { return nil, err } diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index 1ec45e03..4eeaf037 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -53,7 +53,6 @@ const selectProfilesBySearchSQL = "" + type profilesStatements struct { db *sql.DB - writer sqlutil.Writer insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt @@ -61,9 +60,8 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { +func (s *profilesStatements) prepare(db *sql.DB) (err error) { s.db = db - s.writer = writer _, err = db.Exec(profilesSchema) if err != nil { return @@ -88,11 +86,9 @@ func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err err func (s *profilesStatements) insertProfile( ctx context.Context, txn *sql.Tx, localpart string, -) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") - return err - }) +) error { + _, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + return err } func (s *profilesStatements) selectProfileByLocalpart( @@ -109,16 +105,18 @@ func (s *profilesStatements) selectProfileByLocalpart( } func (s *profilesStatements) setAvatarURL( - ctx context.Context, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ) (err error) { - _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) + stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) + _, err = stmt.ExecContext(ctx, avatarURL, localpart) return } func (s *profilesStatements) setDisplayName( - ctx context.Context, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, localpart string, displayName string, ) (err error) { - _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) + stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) + _, err = stmt.ExecContext(ctx, displayName, localpart) return } diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 4f45f754..46106297 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -64,16 +64,16 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = partitions.Prepare(db, d.writer, "account"); err != nil { return nil, err } - if err = d.accounts.prepare(db, d.writer, serverName); err != nil { + if err = d.accounts.prepare(db, serverName); err != nil { return nil, err } - if err = d.profiles.prepare(db, d.writer); err != nil { + if err = d.profiles.prepare(db); err != nil { return nil, err } - if err = d.accountDatas.prepare(db, d.writer); err != nil { + if err = d.accountDatas.prepare(db); err != nil { return nil, err } - if err = d.threepids.prepare(db, d.writer); err != nil { + if err = d.threepids.prepare(db); err != nil { return nil, err } return d, nil @@ -109,7 +109,9 @@ func (d *Database) SetAvatarURL( ) error { d.profilesMu.Lock() defer d.profilesMu.Unlock() - return d.profiles.setAvatarURL(ctx, localpart, avatarURL) + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL) + }) } // SetDisplayName updates the display name of the profile associated with the given @@ -119,7 +121,9 @@ func (d *Database) SetDisplayName( ) error { d.profilesMu.Lock() defer d.profilesMu.Unlock() - return d.profiles.setDisplayName(ctx, localpart, displayName) + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.profiles.setDisplayName(ctx, txn, localpart, displayName) + }) } // CreateGuestAccount makes a new guest account and creates an empty profile @@ -136,7 +140,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, er defer d.profilesMu.Unlock() defer d.accountDatasMu.Unlock() defer d.accountsMu.Unlock() - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) if err != nil { @@ -162,7 +166,7 @@ func (d *Database) CreateAccount( defer d.profilesMu.Unlock() defer d.accountDatasMu.Unlock() defer d.accountsMu.Unlock() - err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err }) @@ -214,7 +218,7 @@ func (d *Database) SaveAccountData( ) error { d.accountDatasMu.Lock() defer d.accountDatasMu.Unlock() - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) }) } @@ -267,7 +271,7 @@ func (d *Database) SaveThreePIDAssociation( ) (err error) { d.threepidsMu.Lock() defer d.threepidsMu.Unlock() - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { user, err := d.threepids.selectLocalpartForThreePID( ctx, txn, threepid, medium, ) @@ -292,7 +296,9 @@ func (d *Database) RemoveThreePIDAssociation( ) (err error) { d.threepidsMu.Lock() defer d.threepidsMu.Unlock() - return d.threepids.deleteThreePID(ctx, threepid, medium) + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.threepids.deleteThreePID(ctx, txn, threepid, medium) + }) } // GetLocalpartForThreePID looks up the localpart associated with a given third-party diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 230978fe..43112d38 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -54,16 +54,14 @@ const deleteThreePIDSQL = "" + type threepidStatements struct { db *sql.DB - writer sqlutil.Writer selectLocalpartForThreePIDStmt *sql.Stmt selectThreePIDsForLocalpartStmt *sql.Stmt insertThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { +func (s *threepidStatements) prepare(db *sql.DB) (err error) { s.db = db - s.writer = writer _, err = db.Exec(threepidSchema) if err != nil { return @@ -122,18 +120,14 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) insertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err := stmt.ExecContext(ctx, threepid, medium, localpart) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + return err } func (s *threepidStatements) deleteThreePID( - ctx context.Context, threepid string, medium string) (err error) { - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) - _, err := stmt.ExecContext(ctx, threepid, medium) - return err - }) + ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium) + return err } |