aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-06-18 18:36:03 +0100
committerGitHub <noreply@github.com>2020-06-18 18:36:03 +0100
commitdc0bac85d5bad933d32ee63f8bc1aef6348ca6e9 (patch)
tree78d5e0fc237e0104d525071af83d550d6a314d49 /userapi
parent3547a1768c36626c672e5c7834f297496f568b2f (diff)
Refactor account data (#1150)
* Refactor account data * Tweak database fetching * Tweaks * Restore syncProducer notification * Various tweaks, update tag behaviour * Fix initial sync
Diffstat (limited to 'userapi')
-rw-r--r--userapi/api/api.go27
-rw-r--r--userapi/internal/api.go31
-rw-r--r--userapi/inthttp/client.go10
-rw-r--r--userapi/storage/accounts/interface.go7
-rw-r--r--userapi/storage/accounts/postgres/account_data_table.go48
-rw-r--r--userapi/storage/accounts/postgres/storage.go13
-rw-r--r--userapi/storage/accounts/sqlite3/account_data_table.go50
-rw-r--r--userapi/storage/accounts/sqlite3/storage.go13
8 files changed, 112 insertions, 87 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go
index c953a5ba..a80adf2d 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -16,12 +16,14 @@ package api
import (
"context"
+ "encoding/json"
"github.com/matrix-org/gomatrixserverlib"
)
// UserInternalAPI is the internal API for information about users and devices.
type UserInternalAPI interface {
+ InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
@@ -30,6 +32,18 @@ type UserInternalAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
}
+// InputAccountDataRequest is the request for InputAccountData
+type InputAccountDataRequest struct {
+ UserID string // required: the user to set account data for
+ RoomID string // optional: the room to associate the account data with
+ DataType string // optional: the data type of the data
+ AccountData json.RawMessage // required: the message content
+}
+
+// InputAccountDataResponse is the response for InputAccountData
+type InputAccountDataResponse struct {
+}
+
// QueryAccessTokenRequest is the request for QueryAccessToken
type QueryAccessTokenRequest struct {
AccessToken string
@@ -46,18 +60,15 @@ type QueryAccessTokenResponse struct {
// QueryAccountDataRequest is the request for QueryAccountData
type QueryAccountDataRequest struct {
- UserID string // required: the user to get account data for.
- // TODO: This is a terribly confusing API shape :/
- DataType string // optional: if specified returns only a single event matching this data type.
- // optional: Only used if DataType is set. If blank returns global account data matching the data type.
- // If set, returns only room account data matching this data type.
- RoomID string
+ UserID string // required: the user to get account data for.
+ RoomID string // optional: the room ID, or global account data if not specified.
+ DataType string // optional: the data type, or all types if not specified.
}
// QueryAccountDataResponse is the response for QueryAccountData
type QueryAccountDataResponse struct {
- GlobalAccountData []gomatrixserverlib.ClientEvent
- RoomAccountData map[string][]gomatrixserverlib.ClientEvent
+ GlobalAccountData map[string]json.RawMessage // type -> data
+ RoomAccountData map[string]map[string]json.RawMessage // room -> type -> data
}
// QueryDevicesRequest is the request for QueryDevices
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index ae021f57..b081eca4 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -17,6 +17,7 @@ package internal
import (
"context"
"database/sql"
+ "encoding/json"
"errors"
"fmt"
@@ -38,6 +39,20 @@ type UserInternalAPI struct {
AppServices []config.ApplicationService
}
+func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
+ local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
+ if err != nil {
+ return err
+ }
+ if domain != a.ServerName {
+ return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
+ }
+ if req.DataType == "" {
+ return fmt.Errorf("data type must not be empty")
+ }
+ return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
+}
+
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
if req.AccountType == api.AccountTypeGuest {
acc, err := a.AccountDB.CreateGuestAccount(ctx)
@@ -130,17 +145,21 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName)
}
if req.DataType != "" {
- var event *gomatrixserverlib.ClientEvent
- event, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
+ var data json.RawMessage
+ data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil {
return err
}
- if event != nil {
+ res.RoomAccountData = make(map[string]map[string]json.RawMessage)
+ res.GlobalAccountData = make(map[string]json.RawMessage)
+ if data != nil {
if req.RoomID != "" {
- res.RoomAccountData = make(map[string][]gomatrixserverlib.ClientEvent)
- res.RoomAccountData[req.RoomID] = []gomatrixserverlib.ClientEvent{*event}
+ if _, ok := res.RoomAccountData[req.RoomID]; !ok {
+ res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage)
+ }
+ res.RoomAccountData[req.RoomID][req.DataType] = data
} else {
- res.GlobalAccountData = append(res.GlobalAccountData, *event)
+ res.GlobalAccountData[req.DataType] = data
}
}
return nil
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index 0e9628c5..4ab0d690 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -26,6 +26,8 @@ import (
// HTTP paths for the internal HTTP APIs
const (
+ InputAccountDataPath = "/userapi/inputAccountData"
+
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation"
@@ -55,6 +57,14 @@ type httpUserInternalAPI struct {
httpClient *http.Client
}
+func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData")
+ defer span.Finish()
+
+ apiURL := h.apiURL + InputAccountDataPath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
+}
+
func (h *httpUserInternalAPI) PerformAccountCreation(
ctx context.Context,
request *api.PerformAccountCreationRequest,
diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go
index 13e3e289..c6692879 100644
--- a/userapi/storage/accounts/interface.go
+++ b/userapi/storage/accounts/interface.go
@@ -16,6 +16,7 @@ package accounts
import (
"context"
+ "encoding/json"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@@ -39,13 +40,13 @@ type Database interface {
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
- SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
- GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
+ SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
+ GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
- GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
+ GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go
index 2f16c5c0..90c79e87 100644
--- a/userapi/storage/accounts/postgres/account_data_table.go
+++ b/userapi/storage/accounts/postgres/account_data_table.go
@@ -17,9 +17,9 @@ package postgres
import (
"context"
"database/sql"
+ "encoding/json"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
@@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
}
func (s *accountDataStatements) insertAccountData(
- ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
+ ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := txn.Stmt(s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
@@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
- global []gomatrixserverlib.ClientEvent,
- rooms map[string][]gomatrixserverlib.ClientEvent,
- err error,
+ /* global */ map[string]json.RawMessage,
+ /* rooms */ map[string]map[string]json.RawMessage,
+ error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
- return
+ return nil, nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
- global = []gomatrixserverlib.ClientEvent{}
- rooms = make(map[string][]gomatrixserverlib.ClientEvent)
+ global := map[string]json.RawMessage{}
+ rooms := map[string]map[string]json.RawMessage{}
for rows.Next() {
var roomID string
@@ -102,41 +102,33 @@ func (s *accountDataStatements) selectAccountData(
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
- return
- }
-
- ac := gomatrixserverlib.ClientEvent{
- Type: dataType,
- Content: content,
+ return nil, nil, err
}
- if len(roomID) > 0 {
- rooms[roomID] = append(rooms[roomID], ac)
+ if roomID != "" {
+ if _, ok := rooms[roomID]; !ok {
+ rooms[roomID] = map[string]json.RawMessage{}
+ }
+ rooms[roomID][dataType] = content
} else {
- global = append(global, ac)
+ global[dataType] = content
}
}
+
return global, rooms, rows.Err()
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
-) (data *gomatrixserverlib.ClientEvent, err error) {
+) (data json.RawMessage, err error) {
+ var bytes []byte
stmt := s.selectAccountDataByTypeStmt
- var content []byte
-
- if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
+ if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
-
return
}
-
- data = &gomatrixserverlib.ClientEvent{
- Type: dataType,
- Content: content,
- }
-
+ data = json.RawMessage(bytes)
return
}
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go
index 2b88cb70..e5509980 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/accounts/postgres/storage.go
@@ -17,6 +17,7 @@ package postgres
import (
"context"
"database/sql"
+ "encoding/json"
"errors"
"strconv"
@@ -169,7 +170,7 @@ func (d *Database) createAccount(
return nil, err
}
- if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
+ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
@@ -177,7 +178,7 @@ func (d *Database) createAccount(
"sender": [],
"underride": []
}
- }`); err != nil {
+ }`)); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@@ -295,7 +296,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
- ctx context.Context, localpart, roomID, dataType, content string,
+ ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@@ -306,8 +307,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
- global []gomatrixserverlib.ClientEvent,
- rooms map[string][]gomatrixserverlib.ClientEvent,
+ global map[string]json.RawMessage,
+ rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
@@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
-) (data *gomatrixserverlib.ClientEvent, err error) {
+) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go
index b6bb6361..d048dbd1 100644
--- a/userapi/storage/accounts/sqlite3/account_data_table.go
+++ b/userapi/storage/accounts/sqlite3/account_data_table.go
@@ -17,8 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
-
- "github.com/matrix-org/gomatrixserverlib"
+ "encoding/json"
)
const accountDataSchema = `
@@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
}
func (s *accountDataStatements) insertAccountData(
- ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
+ ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return
@@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
- global []gomatrixserverlib.ClientEvent,
- rooms map[string][]gomatrixserverlib.ClientEvent,
- err error,
+ /* global */ map[string]json.RawMessage,
+ /* rooms */ map[string]map[string]json.RawMessage,
+ error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
- return
+ return nil, nil, err
}
- global = []gomatrixserverlib.ClientEvent{}
- rooms = make(map[string][]gomatrixserverlib.ClientEvent)
+ global := map[string]json.RawMessage{}
+ rooms := map[string]map[string]json.RawMessage{}
for rows.Next() {
var roomID string
@@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData(
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
- return
+ return nil, nil, err
}
- ac := gomatrixserverlib.ClientEvent{
- Type: dataType,
- Content: content,
- }
-
- if len(roomID) > 0 {
- rooms[roomID] = append(rooms[roomID], ac)
+ if roomID != "" {
+ if _, ok := rooms[roomID]; !ok {
+ rooms[roomID] = map[string]json.RawMessage{}
+ }
+ rooms[roomID][dataType] = content
} else {
- global = append(global, ac)
+ global[dataType] = content
}
}
- return
+ return global, rooms, nil
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
-) (data *gomatrixserverlib.ClientEvent, err error) {
+) (data json.RawMessage, err error) {
+ var bytes []byte
stmt := s.selectAccountDataByTypeStmt
- var content []byte
-
- if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
+ if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
-
return
}
-
- data = &gomatrixserverlib.ClientEvent{
- Type: dataType,
- Content: content,
- }
-
+ data = json.RawMessage(bytes)
return
}
diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go
index 4dd755a7..dbf6606c 100644
--- a/userapi/storage/accounts/sqlite3/storage.go
+++ b/userapi/storage/accounts/sqlite3/storage.go
@@ -17,6 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
+ "encoding/json"
"errors"
"strconv"
"sync"
@@ -180,7 +181,7 @@ func (d *Database) createAccount(
return nil, err
}
- if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
+ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
@@ -188,7 +189,7 @@ func (d *Database) createAccount(
"sender": [],
"underride": []
}
- }`); err != nil {
+ }`)); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@@ -306,7 +307,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
- ctx context.Context, localpart, roomID, dataType, content string,
+ ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@@ -317,8 +318,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
- global []gomatrixserverlib.ClientEvent,
- rooms map[string][]gomatrixserverlib.ClientEvent,
+ global map[string]json.RawMessage,
+ rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
@@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
-) (data *gomatrixserverlib.ClientEvent, err error) {
+) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)