diff options
Diffstat (limited to 'userapi/storage/postgres/threepid_table.go')
-rw-r--r-- | userapi/storage/postgres/threepid_table.go | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 9280fc87..63c08d61 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -58,12 +59,13 @@ type threepidStatements struct { deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(threepidSchema) +func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { + s := &threepidStatements{} + _, err := db.Exec(threepidSchema) if err != nil { - return + return nil, err } - return sqlutil.StatementList{ + return s, sqlutil.StatementList{ {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL}, @@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) { }.Prepare(db) } -func (s *threepidStatements) selectLocalpartForThreePID( +func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, ) (localpart string, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) @@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID( return } -func (s *threepidStatements) selectThreePIDsForLocalpart( +func (s *threepidStatements) SelectThreePIDsForLocalpart( ctx context.Context, localpart string, ) (threepids []authtypes.ThreePID, err error) { rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) @@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart( return } -func (s *threepidStatements) insertThreePID( +func (s *threepidStatements) InsertThreePID( ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) @@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID( return } -func (s *threepidStatements) deleteThreePID( - ctx context.Context, threepid string, medium string) (err error) { - _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) +func (s *threepidStatements) DeleteThreePID( + ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { + stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) + _, err = stmt.ExecContext(ctx, threepid, medium) return } |