diff options
Diffstat (limited to 'internal/sqlutil/sql.go')
-rw-r--r-- | internal/sqlutil/sql.go | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index d296c418..1d2825d5 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -38,9 +38,18 @@ type Transaction interface { // was applied correctly. For example, 'database is locked' errors in sqlite will happen here. func EndTransaction(txn Transaction, succeeded *bool) error { if *succeeded { - return txn.Commit() // nolint: errcheck + return txn.Commit() } else { - return txn.Rollback() // nolint: errcheck + return txn.Rollback() + } +} + +// EndTransactionWithCheck ends a transaction and overwrites the error pointer if its value was nil. +// If the transaction succeeded then it is committed, otherwise it is rolledback. +// Designed to be used with defer (see EndTransaction otherwise). +func EndTransactionWithCheck(txn Transaction, succeeded *bool, err *error) { + if e := EndTransaction(txn, succeeded); e != nil && *err == nil { + *err = e } } @@ -53,12 +62,7 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err) } succeeded := false - defer func() { - err2 := EndTransaction(txn, &succeeded) - if err == nil && err2 != nil { // failed to commit/rollback - err = err2 - } - }() + defer EndTransactionWithCheck(txn, &succeeded, &err) err = fn(txn) if err != nil { |