diff options
Diffstat (limited to 'mediaapi/storage/postgres/media_repository_table.go')
-rw-r--r-- | mediaapi/storage/postgres/media_repository_table.go | 33 |
1 files changed, 18 insertions, 15 deletions
diff --git a/mediaapi/storage/postgres/media_repository_table.go b/mediaapi/storage/postgres/media_repository_table.go index 1d3264ca..41cee487 100644 --- a/mediaapi/storage/postgres/media_repository_table.go +++ b/mediaapi/storage/postgres/media_repository_table.go @@ -20,6 +20,8 @@ import ( "database/sql" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/mediaapi/storage/tables" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -69,24 +71,25 @@ type mediaStatements struct { selectMediaByHashStmt *sql.Stmt } -func (s *mediaStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(mediaSchema) +func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) { + s := &mediaStatements{} + _, err := db.Exec(mediaSchema) if err != nil { - return + return nil, err } - return statementList{ + return s, sqlutil.StatementList{ {&s.insertMediaStmt, insertMediaSQL}, {&s.selectMediaStmt, selectMediaSQL}, {&s.selectMediaByHashStmt, selectMediaByHashSQL}, - }.prepare(db) + }.Prepare(db) } -func (s *mediaStatements) insertMedia( - ctx context.Context, mediaMetadata *types.MediaMetadata, +func (s *mediaStatements) InsertMedia( + ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata, ) error { - mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - _, err := s.insertMediaStmt.ExecContext( + mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now()) + _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext( ctx, mediaMetadata.MediaID, mediaMetadata.Origin, @@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia( return err } -func (s *mediaStatements) selectMedia( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMedia( + ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ MediaID: mediaID, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext( ctx, mediaMetadata.MediaID, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, @@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia( return &mediaMetadata, err } -func (s *mediaStatements) selectMediaByHash( - ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +func (s *mediaStatements) SelectMediaByHash( + ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, ) (*types.MediaMetadata, error) { mediaMetadata := types.MediaMetadata{ Base64Hash: mediaHash, Origin: mediaOrigin, } - err := s.selectMediaStmt.QueryRowContext( + err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext( ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, ).Scan( &mediaMetadata.ContentType, |