aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage/sqlite3/stream_id_table.go
blob: 260f7a95d7a1d0d0d4f31e6a8acf65d989f73e5f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package sqlite3

import (
	"context"
	"database/sql"

	"github.com/matrix-org/dendrite/common"
	"github.com/matrix-org/dendrite/syncapi/types"
)

// streamIDTableSchema creates the syncapi_stream_id table if it does not yet
// exist and seeds it with a single counter row named 'global' starting at 0.
// The seed insert is a no-op when the row is already present, so the whole
// script is safe to execute on every start-up.
//
// The stream name is written as a single-quoted SQL string literal. Double
// quotes denote identifiers in SQL; SQLite only treats them as strings via a
// legacy fallback that can be disabled at build time.
const streamIDTableSchema = `
-- Global stream ID counter, used by other tables.
CREATE TABLE IF NOT EXISTS syncapi_stream_id (
  stream_name TEXT NOT NULL PRIMARY KEY,
  stream_id INT DEFAULT 0,

  UNIQUE(stream_name)
);
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ('global', 0)
  ON CONFLICT DO NOTHING;
`

// increaseStreamIDStmt bumps the counter of the stream named by $1 by one.
const increaseStreamIDStmt = "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1"

// selectStreamIDStmt reads the current counter value of the stream named by $1.
const selectStreamIDStmt = "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"

// streamIDStatements holds the prepared statements that operate on the
// syncapi_stream_id table: one that increments a stream's counter and one
// that reads it back. Both fields are populated by prepare.
type streamIDStatements struct {
	increaseStreamIDStmt *sql.Stmt
	selectStreamIDStmt   *sql.Stmt
}

// prepare executes streamIDTableSchema against db and then prepares the
// increment and select statements, storing them on s. It stops at, and
// returns, the first error encountered; nil is returned when every step
// succeeds.
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
	// Make sure the table (and its seed row) exist before preparing
	// statements that reference it.
	if _, err = db.Exec(streamIDTableSchema); err != nil {
		return err
	}

	s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt)
	if err != nil {
		return err
	}

	s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt)
	return err
}

// nextStreamID increments the counter of the "global" stream and returns the
// value read back afterwards. Both statements are obtained through
// common.TxStmt with txn (presumably so they run inside that transaction —
// confirm against common.TxStmt). On failure the error of the failing step is
// returned.
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
	const streamName = "global"

	bump := common.TxStmt(txn, s.increaseStreamIDStmt)
	read := common.TxStmt(txn, s.selectStreamIDStmt)

	// Increment first, then read, so the returned position is the freshly
	// bumped value.
	_, err = bump.ExecContext(ctx, streamName)
	if err != nil {
		return pos, err
	}

	err = read.QueryRowContext(ctx, streamName).Scan(&pos)
	return pos, err
}