aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-06-18 18:36:03 +0100
committerGitHub <noreply@github.com>2020-06-18 18:36:03 +0100
commitdc0bac85d5bad933d32ee63f8bc1aef6348ca6e9 (patch)
tree78d5e0fc237e0104d525071af83d550d6a314d49 /userapi/storage
parent3547a1768c36626c672e5c7834f297496f568b2f (diff)
Refactor account data (#1150)
* Refactor account data * Tweak database fetching * Tweaks * Restore syncProducer notification * Various tweaks, update tag behaviour * Fix initial sync
Diffstat (limited to 'userapi/storage')
-rw-r--r--userapi/storage/accounts/interface.go7
-rw-r--r--userapi/storage/accounts/postgres/account_data_table.go48
-rw-r--r--userapi/storage/accounts/postgres/storage.go13
-rw-r--r--userapi/storage/accounts/sqlite3/account_data_table.go50
-rw-r--r--userapi/storage/accounts/sqlite3/storage.go13
5 files changed, 58 insertions, 73 deletions
diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go
index 13e3e289..c6692879 100644
--- a/userapi/storage/accounts/interface.go
+++ b/userapi/storage/accounts/interface.go
@@ -16,6 +16,7 @@ package accounts
import (
"context"
+ "encoding/json"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@@ -39,13 +40,13 @@ type Database interface {
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
- SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
- GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
+ SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
+ GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
- GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
+ GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go
index 2f16c5c0..90c79e87 100644
--- a/userapi/storage/accounts/postgres/account_data_table.go
+++ b/userapi/storage/accounts/postgres/account_data_table.go
@@ -17,9 +17,9 @@ package postgres
import (
"context"
"database/sql"
+ "encoding/json"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
@@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
}
func (s *accountDataStatements) insertAccountData(
- ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
+ ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := txn.Stmt(s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
@@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
- global []gomatrixserverlib.ClientEvent,
- rooms map[string][]gomatrixserverlib.ClientEvent,
- err error,
+ /* global */ map[string]json.RawMessage,
+ /* rooms */ map[string]map[string]json.RawMessage,
+ error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
- return
+ return nil, nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
- global = []gomatrixserverlib.ClientEvent{}
- rooms = make(map[string][]gomatrixserverlib.ClientEvent)
+ global := map[string]json.RawMessage{}
+ rooms := map[string]map[string]json.RawMessage{}
for rows.Next() {
var roomID string
@@ -102,41 +102,33 @@ func (s *accountDataStatements) selectAccountData(
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
- return
- }
-
- ac := gomatrixserverlib.ClientEvent{
- Type: dataType,
- Content: content,
+ return nil, nil, err
}
- if len(roomID) > 0 {
- rooms[roomID] = append(rooms[roomID], ac)
+ if roomID != "" {
+ if _, ok := rooms[roomID]; !ok {
+ rooms[roomID] = map[string]json.RawMessage{}
+ }
+ rooms[roomID][dataType] = content
} else {
- global = append(global, ac)
+ global[dataType] = content
}
}
+
return global, rooms, rows.Err()
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
-) (data *gomatrixserverlib.ClientEvent, err error) {
+) (data json.RawMessage, err error) {
+ var bytes []byte
stmt := s.selectAccountDataByTypeStmt
- var content []byte
-
- if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
+ if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
-
return
}
-
- data = &gomatrixserverlib.ClientEvent{
- Type: dataType,
- Content: content,
- }
-
+ data = json.RawMessage(bytes)
return
}
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go
index 2b88cb70..e5509980 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/accounts/postgres/storage.go
@@ -17,6 +17,7 @@ package postgres
import (
"context"
"database/sql"
+ "encoding/json"
"errors"
"strconv"
@@ -169,7 +170,7 @@ func (d *Database) createAccount(
return nil, err
}
- if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
+ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
@@ -177,7 +178,7 @@ func (d *Database) createAccount(
"sender": [],
"underride": []
}
- }`); err != nil {
+ }`)); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@@ -295,7 +296,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
- ctx context.Context, localpart, roomID, dataType, content string,
+ ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@@ -306,8 +307,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
- global []gomatrixserverlib.ClientEvent,
- rooms map[string][]gomatrixserverlib.ClientEvent,
+ global map[string]json.RawMessage,
+ rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
@@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
-) (data *gomatrixserverlib.ClientEvent, err error) {
+) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go
index b6bb6361..d048dbd1 100644
--- a/userapi/storage/accounts/sqlite3/account_data_table.go
+++ b/userapi/storage/accounts/sqlite3/account_data_table.go
@@ -17,8 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
-
- "github.com/matrix-org/gomatrixserverlib"
+ "encoding/json"
)
const accountDataSchema = `
@@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
}
func (s *accountDataStatements) insertAccountData(
- ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
+ ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return
@@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
- global []gomatrixserverlib.ClientEvent,
- rooms map[string][]gomatrixserverlib.ClientEvent,
- err error,
+ /* global */ map[string]json.RawMessage,
+ /* rooms */ map[string]map[string]json.RawMessage,
+ error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
- return
+ return nil, nil, err
}
- global = []gomatrixserverlib.ClientEvent{}
- rooms = make(map[string][]gomatrixserverlib.ClientEvent)
+ global := map[string]json.RawMessage{}
+ rooms := map[string]map[string]json.RawMessage{}
for rows.Next() {
var roomID string
@@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData(
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
- return
+ return nil, nil, err
}
- ac := gomatrixserverlib.ClientEvent{
- Type: dataType,
- Content: content,
- }
-
- if len(roomID) > 0 {
- rooms[roomID] = append(rooms[roomID], ac)
+ if roomID != "" {
+ if _, ok := rooms[roomID]; !ok {
+ rooms[roomID] = map[string]json.RawMessage{}
+ }
+ rooms[roomID][dataType] = content
} else {
- global = append(global, ac)
+ global[dataType] = content
}
}
- return
+ return global, rooms, nil
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
-) (data *gomatrixserverlib.ClientEvent, err error) {
+) (data json.RawMessage, err error) {
+ var bytes []byte
stmt := s.selectAccountDataByTypeStmt
- var content []byte
-
- if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
+ if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
-
return
}
-
- data = &gomatrixserverlib.ClientEvent{
- Type: dataType,
- Content: content,
- }
-
+ data = json.RawMessage(bytes)
return
}
diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go
index 4dd755a7..dbf6606c 100644
--- a/userapi/storage/accounts/sqlite3/storage.go
+++ b/userapi/storage/accounts/sqlite3/storage.go
@@ -17,6 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
+ "encoding/json"
"errors"
"strconv"
"sync"
@@ -180,7 +181,7 @@ func (d *Database) createAccount(
return nil, err
}
- if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
+ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
@@ -188,7 +189,7 @@ func (d *Database) createAccount(
"sender": [],
"underride": []
}
- }`); err != nil {
+ }`)); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@@ -306,7 +307,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
- ctx context.Context, localpart, roomID, dataType, content string,
+ ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@@ -317,8 +318,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
- global []gomatrixserverlib.ClientEvent,
- rooms map[string][]gomatrixserverlib.ClientEvent,
+ global map[string]json.RawMessage,
+ rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
@@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
-) (data *gomatrixserverlib.ClientEvent, err error) {
+) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)