aboutsummaryrefslogtreecommitdiff
path: root/setup/mscs/msc2836/storage.go
blob: 6a45f08a4e5dfed308aaac78bf49e7e270dc6ed8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
package msc2836

import (
	"bytes"
	"context"
	"database/sql"
	"encoding/base64"
	"encoding/json"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/roomserver/types"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/gomatrixserverlib/spec"
	"github.com/matrix-org/util"
)

// eventInfo is the small per-event record handed back by the Database query methods.
// EventID and RoomID identify the event and the room it belongs to; OriginServerTS is the
// event's origin_server_ts. ParentForChild on DB leaves OriginServerTS at its zero value
// because its query only selects the parent's event ID and room ID.
type eventInfo struct {
	EventID        string
	OriginServerTS spec.Timestamp
	RoomID         string
}

// Database is the storage API used by this package to persist relationships between events
// (parent/child edges) and per-event child metadata, and to query them back.
type Database interface {
	// StoreRelation stores the parent->child and child->parent relationship for later querying.
	// Also stores the event metadata, e.g. the timestamp.
	StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error
	// ChildrenForParent returns the events which have the given `eventID` as an m.relationship with the
	// provided `relType`. The returned slice is sorted by origin_server_ts, newest first when
	// `recentFirst` is true and oldest first otherwise.
	ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
	// ParentForChild returns the parent event for the given child `eventID`. The eventInfo should be nil if
	// there is no parent for this child event, with no error. The parent eventInfo can be missing the
	// timestamp if the event is not known to the server (the DB implementation in this file only ever
	// fills in the event ID and room ID).
	ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error)
	// UpdateChildMetadata persists the children count and children_hash from this event if the count is
	// greater than what was previously there, or if the count is equal but the hash differs. Whenever
	// the metadata is updated, the event is set back to unexplored. Events carrying a zero children
	// count are ignored.
	UpdateChildMetadata(ctx context.Context, ev *types.HeaderedEvent) error
	// ChildMetadata returns the children_count and children_hash for the event ID in question.
	// Also returns the `explored` flag, which is set to true when MarkChildrenExplored is called and is set
	// back to `false` when new metadata is written via UpdateChildMetadata.
	// Returns nil error if the event ID does not exist; the DB implementation then returns zero values.
	ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error)
	// MarkChildrenExplored sets the 'explored' flag on this event to `true`.
	MarkChildrenExplored(ctx context.Context, eventID string) error
}

// DB is the SQL-backed implementation of Database. It holds the database handle, the writer
// that StoreRelation uses to run its write transaction, and the statements prepared once by
// newPostgresDatabase / newSQLiteDatabase:
//   - insertEdgeStmt / insertNodeStmt add rows to msc2836_edges / msc2836_nodes.
//   - selectChildrenForParent{OldestFirst,RecentFirst}Stmt list children ordered by origin_server_ts.
//   - selectParentForChildStmt looks up the parent edge of a child event.
//   - updateChildMetadataStmt / selectChildMetadataStmt write and read the children count, hash
//     and explored flag of a node; updateChildMetadataExploredStmt writes the explored flag only.
type DB struct {
	db                                     *sql.DB
	writer                                 sqlutil.Writer
	insertEdgeStmt                         *sql.Stmt
	insertNodeStmt                         *sql.Stmt
	selectChildrenForParentOldestFirstStmt *sql.Stmt
	selectChildrenForParentRecentFirstStmt *sql.Stmt
	selectParentForChildStmt               *sql.Stmt
	updateChildMetadataStmt                *sql.Stmt
	selectChildMetadataStmt                *sql.Stmt
	updateChildMetadataExploredStmt        *sql.Stmt
}

// NewDatabase opens the msc2836 database described by dbOpts. A Postgres connection string
// selects the Postgres backend; anything else is opened with the SQLite backend.
func NewDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
	open := newSQLiteDatabase
	if dbOpts.ConnectionString.IsPostgres() {
		open = newPostgresDatabase
	}
	return open(conMan, dbOpts)
}

// newPostgresDatabase opens the connection described by dbOpts, creates the msc2836_edges and
// msc2836_nodes tables when they do not exist yet and prepares every statement used by DB.
// The first failure is returned as-is and no Database is handed back.
func newPostgresDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
	var (
		d   DB
		err error
	)
	d.db, d.writer, err = conMan.Connection(dbOpts)
	if err != nil {
		return nil, err
	}

	const schema = `
	CREATE TABLE IF NOT EXISTS msc2836_edges (
		parent_event_id TEXT NOT NULL,
		child_event_id TEXT NOT NULL,
		rel_type TEXT NOT NULL,
		parent_room_id TEXT NOT NULL,
		parent_servers TEXT NOT NULL,
		CONSTRAINT msc2836_edges_uniq UNIQUE (parent_event_id, child_event_id, rel_type)
	);

	CREATE TABLE IF NOT EXISTS msc2836_nodes (
		event_id TEXT PRIMARY KEY NOT NULL,
		origin_server_ts BIGINT NOT NULL,
		room_id TEXT NOT NULL,
		unsigned_children_count BIGINT NOT NULL,
		unsigned_children_hash TEXT NOT NULL,
		explored SMALLINT NOT NULL
	);
	`
	if _, err = d.db.Exec(schema); err != nil {
		return nil, err
	}

	// Shared by both child listings; only the sort direction appended at the end differs.
	const selectChildrenQuery = `
	SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
	LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
	WHERE parent_event_id = $1 AND rel_type = $2
	ORDER BY origin_server_ts
	`

	// Statements are prepared in the listed order; the first failure aborts construction.
	statements := []struct {
		dest  **sql.Stmt
		query string
	}{
		{&d.insertEdgeStmt, `
		INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
		VALUES($1, $2, $3, $4, $5)
		ON CONFLICT DO NOTHING
	`},
		{&d.insertNodeStmt, `
		INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
		VALUES($1, $2, $3, $4, $5, $6)
		ON CONFLICT DO NOTHING
	`},
		{&d.selectChildrenForParentOldestFirstStmt, selectChildrenQuery + "ASC"},
		{&d.selectChildrenForParentRecentFirstStmt, selectChildrenQuery + "DESC"},
		{&d.selectParentForChildStmt, `
		SELECT parent_event_id, parent_room_id FROM msc2836_edges
		WHERE child_event_id = $1 AND rel_type = $2
	`},
		{&d.updateChildMetadataStmt, `
		UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
	`},
		{&d.selectChildMetadataStmt, `
		SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
	`},
		{&d.updateChildMetadataExploredStmt, `
		UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
	`},
	}
	for _, s := range statements {
		if *s.dest, err = d.db.Prepare(s.query); err != nil {
			return nil, err
		}
	}
	return &d, nil
}

// newSQLiteDatabase opens the connection described by dbOpts, creates the msc2836_edges and
// msc2836_nodes tables when they do not exist yet and prepares every statement used by DB.
// The first failure is returned as-is and no Database is handed back.
func newSQLiteDatabase(conMan sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
	var (
		d   DB
		err error
	)
	d.db, d.writer, err = conMan.Connection(dbOpts)
	if err != nil {
		return nil, err
	}

	const schema = `
	CREATE TABLE IF NOT EXISTS msc2836_edges (
		parent_event_id TEXT NOT NULL,
		child_event_id TEXT NOT NULL,
		rel_type TEXT NOT NULL,
		parent_room_id TEXT NOT NULL,
		parent_servers TEXT NOT NULL,
		UNIQUE (parent_event_id, child_event_id, rel_type)
	);

	CREATE TABLE IF NOT EXISTS msc2836_nodes (
		event_id TEXT PRIMARY KEY NOT NULL,
		origin_server_ts BIGINT NOT NULL,
		room_id TEXT NOT NULL,
		unsigned_children_count BIGINT NOT NULL,
		unsigned_children_hash TEXT NOT NULL,
		explored SMALLINT NOT NULL
	);
	`
	if _, err = d.db.Exec(schema); err != nil {
		return nil, err
	}

	// Shared by both child listings; only the sort direction appended at the end differs.
	const selectChildrenQuery = `
	SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
	LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
	WHERE parent_event_id = $1 AND rel_type = $2
	ORDER BY origin_server_ts
	`

	// Statements are prepared in the listed order; the first failure aborts construction.
	statements := []struct {
		dest  **sql.Stmt
		query string
	}{
		{&d.insertEdgeStmt, `
		INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
		VALUES($1, $2, $3, $4, $5)
		ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING
	`},
		{&d.insertNodeStmt, `
		INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
		VALUES($1, $2, $3, $4, $5, $6)
		ON CONFLICT DO NOTHING
	`},
		{&d.selectChildrenForParentOldestFirstStmt, selectChildrenQuery + "ASC"},
		{&d.selectChildrenForParentRecentFirstStmt, selectChildrenQuery + "DESC"},
		{&d.selectParentForChildStmt, `
		SELECT parent_event_id, parent_room_id FROM msc2836_edges
		WHERE child_event_id = $1 AND rel_type = $2
	`},
		{&d.updateChildMetadataStmt, `
		UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
	`},
		{&d.selectChildMetadataStmt, `
		SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
	`},
		{&d.updateChildMetadataExploredStmt, `
		UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
	`},
	}
	for _, s := range statements {
		if *s.dest, err = d.db.Prepare(s.query); err != nil {
			return nil, err
		}
	}
	return &d, nil
}

// StoreRelation records the relationship found in ev's content as an edge (parent -> ev) and
// stores a node row for ev holding its timestamp, room ID and child metadata, both inside one
// writer transaction. Events without a usable relationship are silently skipped.
func (p *DB) StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error {
	parentID, childID, relType := parentChildEventIDs(ev)
	if parentID == "" || childID == "" {
		return nil
	}

	roomID, servers := roomIDAndServers(ev)
	serversJSON, err := json.Marshal(servers)
	if err != nil {
		return err
	}
	childCount, childHash := extractChildMetadata(ev)

	storeEdgeAndNode := func(txn *sql.Tx) error {
		if _, execErr := txn.Stmt(p.insertEdgeStmt).ExecContext(ctx, parentID, childID, relType, roomID, string(serversJSON)); execErr != nil {
			return execErr
		}
		util.GetLogger(ctx).Infof("StoreRelation child=%s parent=%s rel_type=%s", childID, parentID, relType)
		// New nodes always start out unexplored (explored = 0).
		_, execErr := txn.Stmt(p.insertNodeStmt).ExecContext(
			ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID(), childCount,
			base64.RawStdEncoding.EncodeToString(childHash), 0,
		)
		return execErr
	}
	return p.writer.Do(p.db, nil, storeEdgeAndNode)
}

// UpdateChildMetadata stores the children count and hash carried in ev's unsigned data when they
// supersede what is already stored for that event: either the count is larger, or the count is
// equal but the hash differs. A successful update also resets the node's explored flag to 0.
// Events with a zero children count are ignored.
//
// The UPDATE is executed through p.writer, in the same way StoreRelation performs its writes, so
// that every write to msc2836_nodes goes through the same writer.
// NOTE(review): the stored values are read before the write transaction starts, so two concurrent
// updates for the same event can still race on the comparison — confirm whether that matters.
func (p *DB) UpdateChildMetadata(ctx context.Context, ev *types.HeaderedEvent) error {
	eventCount, eventHash := extractChildMetadata(ev)
	if eventCount == 0 {
		return nil // nothing to update with
	}

	// Compare against what is currently stored; only newer metadata is persisted.
	count, hash, _, err := p.ChildMetadata(ctx, ev.EventID())
	if err != nil {
		return err
	}
	supersedes := eventCount > count || (eventCount == count && !bytes.Equal(hash, eventHash))
	if !supersedes {
		return nil
	}
	return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
		_, execErr := txn.Stmt(p.updateChildMetadataStmt).ExecContext(
			ctx, eventCount, base64.RawStdEncoding.EncodeToString(eventHash), 0, ev.EventID(),
		)
		return execErr
	})
}

// ChildMetadata looks up the stored children count, children hash and explored flag for eventID.
// The hash is stored as unpadded standard base64 and is decoded before being returned. An unknown
// event ID is not an error: zero values are returned with a nil error.
func (p *DB) ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error) {
	var (
		encodedHash  string
		exploredFlag int
	)
	row := p.selectChildMetadataStmt.QueryRowContext(ctx, eventID)
	switch scanErr := row.Scan(&count, &encodedHash, &exploredFlag); scanErr {
	case nil:
		// Row found; decode it below.
	case sql.ErrNoRows:
		return count, nil, false, nil
	default:
		return count, nil, false, scanErr
	}
	hash, err = base64.RawStdEncoding.DecodeString(encodedHash)
	return count, hash, exploredFlag > 0, err
}

// MarkChildrenExplored sets the explored flag of the node for eventID to 1. The UPDATE is run
// through p.writer, in the same way StoreRelation performs its writes, so that every write to
// msc2836_nodes goes through the same writer. Updating an unknown event ID is not an error.
func (p *DB) MarkChildrenExplored(ctx context.Context, eventID string) error {
	return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
		_, err := txn.Stmt(p.updateChildMetadataExploredStmt).ExecContext(ctx, 1, eventID)
		return err
	})
}

// ChildrenForParent returns the children of eventID for the given relType, ordered by
// origin_server_ts: newest first when recentFirst is true, oldest first otherwise.
// A nil slice is returned when there are no children. Any query, scan or row-iteration error is
// returned and no partial result is handed back.
func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) {
	stmt := p.selectChildrenForParentOldestFirstStmt
	if recentFirst {
		stmt = p.selectChildrenForParentRecentFirstStmt
	}
	rows, err := stmt.QueryContext(ctx, eventID, relType)
	if err != nil {
		return nil, err
	}
	defer rows.Close() // nolint: errcheck

	var children []eventInfo
	for rows.Next() {
		var evInfo eventInfo
		if err = rows.Scan(&evInfo.EventID, &evInfo.OriginServerTS, &evInfo.RoomID); err != nil {
			return nil, err
		}
		children = append(children, evInfo)
	}
	// rows.Next also returns false when iteration fails part-way through; without this check a
	// truncated result would be reported as success.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return children, nil
}

// ParentForChild looks up the parent edge of the child eventID for the given relType. It returns
// (nil, nil) when no such edge exists. Only EventID and RoomID of the result are populated; the
// timestamp is left at its zero value.
func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) {
	parent := new(eventInfo)
	row := p.selectParentForChildStmt.QueryRowContext(ctx, eventID, relType)
	switch err := row.Scan(&parent.EventID, &parent.RoomID); err {
	case nil:
		return parent, nil
	case sql.ErrNoRows:
		return nil, nil
	default:
		return nil, err
	}
}

// parentChildEventIDs reads the m.relationship object from ev's content and returns the related
// (parent) event ID, ev's own ID as the child, and the relationship type. All three results are
// empty when ev is nil, the content cannot be decoded, or the relationship lacks an event ID or
// a rel_type.
func parentChildEventIDs(ev *types.HeaderedEvent) (parent, child, relType string) {
	if ev == nil {
		return "", "", ""
	}
	var content struct {
		Relationship struct {
			RelType string `json:"rel_type"`
			EventID string `json:"event_id"`
		} `json:"m.relationship"`
	}
	if json.Unmarshal(ev.Content(), &content) != nil {
		return "", "", ""
	}
	if rel := content.Relationship; rel.EventID != "" && rel.RelType != "" {
		return rel.EventID, ev.EventID(), rel.RelType
	}
	return "", "", ""
}

// roomIDAndServers reads relationship_room_id and relationship_servers from ev's unsigned data.
// The returned servers slice is never nil: it is empty when ev is nil, when the unsigned data
// cannot be decoded, or when relationship_servers is absent or null. This keeps the value that
// StoreRelation marshals to JSON a list ("[]") on every path instead of "null" on some of them.
func roomIDAndServers(ev *types.HeaderedEvent) (roomID string, servers []string) {
	servers = []string{}
	if ev == nil {
		return
	}
	body := struct {
		RoomID  string   `json:"relationship_room_id"`
		Servers []string `json:"relationship_servers"`
	}{}
	if err := json.Unmarshal(ev.Unsigned(), &body); err != nil {
		return
	}
	// Only replace the empty default when the event actually carried a server list.
	if body.Servers != nil {
		servers = body.Servers
	}
	return body.RoomID, servers
}

// extractChildMetadata reads the `children` counts and `children_hash` from ev's unsigned data.
// count is the sum of all values in the `children` map and hash is the decoded children_hash.
// It returns (0, nil) when ev is nil or when the unsigned data cannot be decoded, matching the
// nil handling of parentChildEventIDs and roomIDAndServers.
func extractChildMetadata(ev *types.HeaderedEvent) (count int, hash []byte) {
	if ev == nil {
		return 0, nil
	}
	unsigned := struct {
		Counts map[string]int   `json:"children"`
		Hash   spec.Base64Bytes `json:"children_hash"`
	}{}
	if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil {
		// expected if there is no unsigned field at all
		return 0, nil
	}
	for _, c := range unsigned.Counts {
		count += c
	}
	return count, unsigned.Hash
}