1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
|
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/clientapi/api"
internal "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"golang.org/x/exp/constraints"
)
// SQL for the userapi_registration_tokens table.
//
// The UPDATE statements use SQLite's "?NNN" placeholders instead of "$N".
// SQLite treats "$1", "$2", ... as *named* parameters and numbers them in order
// of first appearance, so in "SET uses_allowed = $2 ... WHERE token = $1" the
// "$2" placeholder would receive index 1. The callers pass the token first, so
// a driver that binds arguments purely by position would swap the values.
// "?NNN" fixes the parameter index to NNN no matter where it appears.
const (
	// registrationTokensSchema creates the table if it is missing. A NULL
	// uses_allowed or expiry_time is treated as "no limit" by the validity
	// queries below.
	registrationTokensSchema = `
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
token TEXT PRIMARY KEY,
pending BIGINT,
completed BIGINT,
uses_allowed BIGINT,
expiry_time BIGINT
);
`

	selectTokenSQL = "SELECT token FROM userapi_registration_tokens WHERE token = $1"

	insertTokenSQL = "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"

	// The listing queries name their columns explicitly, in the order
	// ListRegistrationTokens scans them, so a future schema change cannot
	// silently shift the result columns. $1 is the current time in Unix
	// milliseconds.
	listAllTokensSQL = "SELECT token, pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens"

	listValidTokensSQL = "SELECT token, pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens" +
		" WHERE (uses_allowed > pending + completed OR uses_allowed IS NULL)" +
		" AND (expiry_time > $1 OR expiry_time IS NULL)"

	listInvalidTokensSQL = "SELECT token, pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens" +
		" WHERE (uses_allowed <= pending + completed OR expiry_time <= $1)"

	getTokenSQL = "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"

	deleteTokenSQL = "DELETE FROM userapi_registration_tokens WHERE token = $1"

	// Arguments for the UPDATEs are (token, value...) — see the note above on
	// why these use ?NNN.
	updateTokenUsesAllowedAndExpiryTimeSQL = "UPDATE userapi_registration_tokens SET uses_allowed = ?2, expiry_time = ?3 WHERE token = ?1"

	updateTokenUsesAllowedSQL = "UPDATE userapi_registration_tokens SET uses_allowed = ?2 WHERE token = ?1"

	updateTokenExpiryTimeSQL = "UPDATE userapi_registration_tokens SET expiry_time = ?2 WHERE token = ?1"
)
type registrationTokenStatements struct {
selectTokenStatement *sql.Stmt
insertTokenStatement *sql.Stmt
listAllTokensStatement *sql.Stmt
listValidTokensStatement *sql.Stmt
listInvalidTokenStatement *sql.Stmt
getTokenStatement *sql.Stmt
deleteTokenStatement *sql.Stmt
updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
updateTokenUsesAllowedStatement *sql.Stmt
updateTokenExpiryTimeStatement *sql.Stmt
}
func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
s := ®istrationTokenStatements{}
_, err := db.Exec(registrationTokensSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectTokenStatement, selectTokenSQL},
{&s.insertTokenStatement, insertTokenSQL},
{&s.listAllTokensStatement, listAllTokensSQL},
{&s.listValidTokensStatement, listValidTokensSQL},
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
{&s.getTokenStatement, getTokenSQL},
{&s.deleteTokenStatement, deleteTokenSQL},
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
}.Prepare(db)
}
// RegistrationTokenExists reports whether a row with the given token is
// present. A missing row (sql.ErrNoRows) is reported as (false, nil); any other
// query error is returned with false.
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
	lookup := sqlutil.TxStmt(tx, s.selectTokenStatement)
	var found string
	switch err := lookup.QueryRowContext(ctx, token).Scan(&found); {
	case err == nil:
		return true, nil
	case err == sql.ErrNoRows:
		return false, nil
	default:
		return false, err
	}
}
// InsertRegistrationToken writes a new token row and reports true once the
// INSERT has succeeded; on failure it returns false and the database error.
//
// UsesAllowed and ExpiryTime are optional: a nil pointer is stored as NULL.
// Token, Pending and Completed are dereferenced unconditionally, so callers
// must set them (a nil value panics).
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
	insert := sqlutil.TxStmt(tx, s.insertTokenStatement)
	// Argument order follows insertTokenSQL's column list:
	// (token, uses_allowed, expiry_time, pending, completed).
	values := []any{
		*registrationToken.Token,
		getInsertValue(registrationToken.UsesAllowed),
		getInsertValue(registrationToken.ExpiryTime),
		*registrationToken.Pending,
		*registrationToken.Completed,
	}
	if _, err := insert.ExecContext(ctx, values...); err != nil {
		return false, err
	}
	return true, nil
}
// getInsertValue converts an optional integer into a query argument: a nil
// pointer becomes a nil argument (bound as SQL NULL), otherwise the pointed-to
// value is returned.
func getInsertValue[T constraints.Integer](value *T) any {
	if value != nil {
		return *value
	}
	return nil
}
// ListRegistrationTokens returns registration tokens from the table.
//
// When returnAll is true every token is returned and valid is ignored.
// Otherwise valid selects between tokens that can still be used (uses remain
// and not expired, per listValidTokensSQL) and tokens that cannot
// (listInvalidTokensSQL). Expiry is compared against the current time in
// milliseconds since the Unix epoch.
//
// On any error the returned slice is nil; callers never see a partial list.
// With no matching rows the result is a nil slice and a nil error.
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
	var (
		rows *sql.Rows
		err  error
	)
	// expiry_time is stored in Unix milliseconds.
	nowMS := time.Now().UnixMilli()
	switch {
	case returnAll:
		rows, err = sqlutil.TxStmt(tx, s.listAllTokensStatement).QueryContext(ctx)
	case valid:
		rows, err = sqlutil.TxStmt(tx, s.listValidTokensStatement).QueryContext(ctx, nowMS)
	default:
		rows, err = sqlutil.TxStmt(tx, s.listInvalidTokenStatement).QueryContext(ctx, nowMS)
	}
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")

	var tokens []api.RegistrationToken
	for rows.Next() {
		// Declared per row because each returned token keeps pointers to these.
		var (
			tokenString                     string
			pending, completed, usesAllowed *int32
			expiryTime                      *int64
		)
		// Scan order must match the column order of the list*TokensSQL queries.
		if err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime); err != nil {
			return nil, err
		}
		tokens = append(tokens, api.RegistrationToken{
			Token:       &tokenString,
			Pending:     pending,
			Completed:   completed,
			UsesAllowed: usesAllowed,
			ExpiryTime:  expiryTime,
		})
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return tokens, nil
}
// GetRegistrationToken loads one token's pending/completed counters, allowed
// uses and expiry time. Any query error (for example sql.ErrNoRows when no row
// matches tokenString) is returned unchanged together with a nil token.
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
	query := sqlutil.TxStmt(tx, s.getTokenStatement)
	var (
		pending, completed, usesAllowed *int32
		expiryTime                      *int64
	)
	row := query.QueryRowContext(ctx, tokenString)
	if err := row.Scan(&pending, &completed, &usesAllowed, &expiryTime); err != nil {
		return nil, err
	}
	return &api.RegistrationToken{
		Token:       &tokenString,
		Pending:     pending,
		Completed:   completed,
		UsesAllowed: usesAllowed,
		ExpiryTime:  expiryTime,
	}, nil
}
// DeleteRegistrationToken removes the row for tokenString. The number of
// affected rows is not inspected, so deleting an unknown token is not treated
// as an error.
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
	remove := sqlutil.TxStmt(tx, s.deleteTokenStatement)
	_, err := remove.ExecContext(ctx, tokenString)
	return err
}
// UpdateRegistrationToken applies the attributes present in newAttributes to
// the token and returns the token as read back afterwards.
//
// Recognised keys are "usesAllowed" and "expiryTime"; their values are passed
// to the database unchanged (so a nil value is stored as NULL). If neither key
// is present nothing is written and the current token is simply returned.
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
	usesAllowed, hasUsesAllowed := newAttributes["usesAllowed"]
	expiryTime, hasExpiryTime := newAttributes["expiryTime"]

	// Pick the UPDATE that matches exactly the attributes supplied; the token
	// is always the first argument.
	var (
		update *sql.Stmt
		args   []any
	)
	switch {
	case hasUsesAllowed && hasExpiryTime:
		update = s.updateTokenUsesAllowedAndExpiryTimeStatement
		args = []any{tokenString, usesAllowed, expiryTime}
	case hasUsesAllowed:
		update = s.updateTokenUsesAllowedStatement
		args = []any{tokenString, usesAllowed}
	case hasExpiryTime:
		update = s.updateTokenExpiryTimeStatement
		args = []any{tokenString, expiryTime}
	}

	if len(args) > 0 {
		if _, err := sqlutil.TxStmt(tx, update).ExecContext(ctx, args...); err != nil {
			return nil, err
		}
	}
	return s.GetRegistrationToken(ctx, tx, tokenString)
}
|