aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi/storage')
-rw-r--r--syncapi/storage/interface.go1
-rw-r--r--syncapi/storage/shared/storage_consumer.go2
-rw-r--r--syncapi/storage/shared/storage_sync.go14
3 files changed, 17 insertions, 0 deletions
diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go
index 4a03aca7..be75f8ad 100644
--- a/syncapi/storage/interface.go
+++ b/syncapi/storage/interface.go
@@ -29,6 +29,7 @@ import (
type DatabaseTransaction interface {
sqlutil.Transaction
+ Reset() (err error)
SharedUsers
MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error)
diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go
index fb3b295e..937ced3a 100644
--- a/syncapi/storage/shared/storage_consumer.go
+++ b/syncapi/storage/shared/storage_consumer.go
@@ -77,6 +77,7 @@ func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransactio
}
return &DatabaseTransaction{
Database: d,
+ ctx: ctx,
txn: txn,
}, nil
*/
@@ -89,6 +90,7 @@ func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransac
}
return &DatabaseTransaction{
Database: d,
+ ctx: ctx,
txn: txn,
}, nil
}
diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go
index a19135a6..6cc83ebc 100644
--- a/syncapi/storage/shared/storage_sync.go
+++ b/syncapi/storage/shared/storage_sync.go
@@ -13,6 +13,7 @@ import (
type DatabaseTransaction struct {
*Database
+ ctx context.Context
txn *sql.Tx
}
@@ -30,6 +31,19 @@ func (d *DatabaseTransaction) Rollback() error {
return d.txn.Rollback()
}
+func (d *DatabaseTransaction) Reset() (err error) {
+ if d.txn == nil {
+ return nil
+ }
+ if err = d.txn.Rollback(); err != nil {
+ return err
+ }
+ if d.txn, err = d.DB.BeginTx(d.ctx, nil); err != nil {
+ return err
+ }
+ return
+}
+
func (d *DatabaseTransaction) MaxStreamPositionForPDUs(ctx context.Context) (types.StreamPosition, error) {
id, err := d.OutputEvents.SelectMaxEventID(ctx, d.txn)
if err != nil {