aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authoroliverpool <3864879+oliverpool@users.noreply.github.com>2020-08-25 14:11:52 +0200
committerGitHub <noreply@github.com>2020-08-25 13:11:52 +0100
commita4db43e0969125db899dae465daf3ab1385c8ce9 (patch)
treea6fa17159a4a5e066b56751c849edba628b710e9 /internal
parentc8b873abc8cb20227774c648b7a774214c8f3752 (diff)
Don't overwrite global err before return (#1293)
Signed-off-by: Olivier Charvin <git@olivier.pfad.fr>
Diffstat (limited to 'internal')
-rw-r--r--internal/sqlutil/partition_offset_table.go24
-rw-r--r--internal/sqlutil/sql.go20
2 files changed, 24 insertions, 20 deletions
diff --git a/internal/sqlutil/partition_offset_table.go b/internal/sqlutil/partition_offset_table.go
index be079442..e19a092f 100644
--- a/internal/sqlutil/partition_offset_table.go
+++ b/internal/sqlutil/partition_offset_table.go
@@ -18,8 +18,6 @@ import (
"context"
"database/sql"
"strings"
-
- "github.com/matrix-org/util"
)
// A PartitionOffset is the offset into a partition of the input log.
@@ -99,26 +97,28 @@ func (s *PartitionOffsetStatements) SetPartitionOffset(
// selectPartitionOffsets returns all the partition offsets for the given topic.
func (s *PartitionOffsetStatements) selectPartitionOffsets(
ctx context.Context, topic string,
-) ([]PartitionOffset, error) {
+) (results []PartitionOffset, err error) {
rows, err := s.selectPartitionOffsetsStmt.QueryContext(ctx, topic)
if err != nil {
return nil, err
}
- defer func() {
- err2 := rows.Close()
- if err2 != nil {
- util.GetLogger(ctx).WithError(err2).Error("selectPartitionOffsets: rows.close() failed")
- }
- }()
- var results []PartitionOffset
+ defer checkNamedErr(rows.Close, &err)
for rows.Next() {
var offset PartitionOffset
- if err := rows.Scan(&offset.Partition, &offset.Offset); err != nil {
+ if err = rows.Scan(&offset.Partition, &offset.Offset); err != nil {
return nil, err
}
results = append(results, offset)
}
- return results, rows.Err()
+ err = rows.Err()
+ return results, err
+}
+
+// checkNamedErr calls fn and overwrite err if it was nil and fn returned non-nil
+func checkNamedErr(fn func() error, err *error) {
+ if e := fn(); e != nil && *err == nil {
+ *err = e
+ }
}
// UpsertPartitionOffset updates or inserts the partition offset for the given topic.
diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go
index d296c418..1d2825d5 100644
--- a/internal/sqlutil/sql.go
+++ b/internal/sqlutil/sql.go
@@ -38,9 +38,18 @@ type Transaction interface {
// was applied correctly. For example, 'database is locked' errors in sqlite will happen here.
func EndTransaction(txn Transaction, succeeded *bool) error {
if *succeeded {
- return txn.Commit() // nolint: errcheck
+ return txn.Commit()
} else {
- return txn.Rollback() // nolint: errcheck
+ return txn.Rollback()
+ }
+}
+
+// EndTransactionWithCheck ends a transaction and overwrites the error pointer if its value was nil.
+// If the transaction succeeded then it is committed, otherwise it is rolledback.
+// Designed to be used with defer (see EndTransaction otherwise).
+func EndTransactionWithCheck(txn Transaction, succeeded *bool, err *error) {
+ if e := EndTransaction(txn, succeeded); e != nil && *err == nil {
+ *err = e
}
}
@@ -53,12 +62,7 @@ func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err)
}
succeeded := false
- defer func() {
- err2 := EndTransaction(txn, &succeeded)
- if err == nil && err2 != nil { // failed to commit/rollback
- err = err2
- }
- }()
+ defer EndTransactionWithCheck(txn, &succeeded, &err)
err = fn(txn)
if err != nil {