aboutsummaryrefslogtreecommitdiff
path: root/mediaapi/storage/postgres/media_repository_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'mediaapi/storage/postgres/media_repository_table.go')
-rw-r--r--mediaapi/storage/postgres/media_repository_table.go33
1 files changed, 18 insertions, 15 deletions
diff --git a/mediaapi/storage/postgres/media_repository_table.go b/mediaapi/storage/postgres/media_repository_table.go
index 1d3264ca..41cee487 100644
--- a/mediaapi/storage/postgres/media_repository_table.go
+++ b/mediaapi/storage/postgres/media_repository_table.go
@@ -20,6 +20,8 @@ import (
"database/sql"
"time"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -69,24 +71,25 @@ type mediaStatements struct {
selectMediaByHashStmt *sql.Stmt
}
-func (s *mediaStatements) prepare(db *sql.DB) (err error) {
- _, err = db.Exec(mediaSchema)
+func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
+ s := &mediaStatements{}
+ _, err := db.Exec(mediaSchema)
if err != nil {
- return
+ return nil, err
}
- return statementList{
+ return s, sqlutil.StatementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
- }.prepare(db)
+ }.Prepare(db)
}
-func (s *mediaStatements) insertMedia(
- ctx context.Context, mediaMetadata *types.MediaMetadata,
+func (s *mediaStatements) InsertMedia(
+ ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
) error {
- mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
- _, err := s.insertMediaStmt.ExecContext(
+ mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
+ _, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
ctx,
mediaMetadata.MediaID,
mediaMetadata.Origin,
@@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia(
return err
}
-func (s *mediaStatements) selectMedia(
- ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
+func (s *mediaStatements) SelectMedia(
+ ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
MediaID: mediaID,
Origin: mediaOrigin,
}
- err := s.selectMediaStmt.QueryRowContext(
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,
@@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia(
return &mediaMetadata, err
}
-func (s *mediaStatements) selectMediaByHash(
- ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
+func (s *mediaStatements) SelectMediaByHash(
+ ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
Base64Hash: mediaHash,
Origin: mediaOrigin,
}
- err := s.selectMediaStmt.QueryRowContext(
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,