diff options
author | oliverpool <3864879+oliverpool@users.noreply.github.com> | 2020-08-25 14:11:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-25 13:11:52 +0100 |
commit | a4db43e0969125db899dae465daf3ab1385c8ce9 (patch) | |
tree | a6fa17159a4a5e066b56751c849edba628b710e9 /internal | |
parent | c8b873abc8cb20227774c648b7a774214c8f3752 (diff) |
Don't overwrite global err before return (#1293)
Signed-off-by: Olivier Charvin <git@olivier.pfad.fr>
Diffstat (limited to 'internal')
-rw-r--r-- | internal/sqlutil/partition_offset_table.go | 24 | ||||
-rw-r--r-- | internal/sqlutil/sql.go | 20 |
2 files changed, 24 insertions, 20 deletions
diff --git a/internal/sqlutil/partition_offset_table.go b/internal/sqlutil/partition_offset_table.go index be079442..e19a092f 100644 --- a/internal/sqlutil/partition_offset_table.go +++ b/internal/sqlutil/partition_offset_table.go @@ -18,8 +18,6 @@ import ( "context" "database/sql" "strings" - - "github.com/matrix-org/util" ) // A PartitionOffset is the offset into a partition of the input log. @@ -99,26 +97,28 @@ func (s *PartitionOffsetStatements) SetPartitionOffset( // selectPartitionOffsets returns all the partition offsets for the given topic. func (s *PartitionOffsetStatements) selectPartitionOffsets( ctx context.Context, topic string, -) ([]PartitionOffset, error) { +) (results []PartitionOffset, err error) { rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic) if err != nil { return nil, err } - defer func() { - err2 := rows.Close() - if err2 != nil { - util.GetLogger(ctx).WithError(err2).Error("selectPartitionOffsets: rows.close() failed") - } - }() - var results []PartitionOffset + defer checkNamedErr(rows.Close, &err) for rows.Next() { var offset PartitionOffset - if err := rows.Scan(&offset.Partition, &offset.Offset); err != nil { + if err = rows.Scan(&offset.Partition, &offset.Offset); err != nil { return nil, err } results = append(results, offset) } - return results, rows.Err() + err = rows.Err() + return results, err +} + +// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil +func checkNamedErr(fn func() error, err *error) { + if e := fn(); e != nil && *err == nil { + *err = e + } } // UpsertPartitionOffset updates or inserts the partition offset for the given topic. diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index d296c418..1d2825d5 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -38,9 +38,18 @@ type Transaction interface { // was applied correctly. For example, 'database is locked' errors in sqlite will happen here. func EndTransaction(txn Transaction, succeeded *bool) error { if *succeeded { - return txn.Commit() // nolint: errcheck + return txn.Commit() } else { - return txn.Rollback() // nolint: errcheck + return txn.Rollback() + } +} + +// EndTransactionWithCheck ends a transaction and overwrites the error pointer if its value was nil. +// If the transaction succeeded then it is committed, otherwise it is rolledback. +// Designed to be used with defer (see EndTransaction otherwise). +func EndTransactionWithCheck(txn Transaction, succeeded *bool, err *error) { + if e := EndTransaction(txn, succeeded); e != nil && *err == nil { + *err = e } } @@ -53,12 +62,7 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) { return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err) } succeeded := false - defer func() { - err2 := EndTransaction(txn, &succeeded) - if err == nil && err2 != nil { // failed to commit/rollback - err = err2 - } - }() + defer EndTransactionWithCheck(txn, &succeeded, &err) err = fn(txn) if err != nil { |