diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2020-06-18 18:36:03 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-18 18:36:03 +0100 |
commit | dc0bac85d5bad933d32ee63f8bc1aef6348ca6e9 (patch) | |
tree | 78d5e0fc237e0104d525071af83d550d6a314d49 /clientapi | |
parent | 3547a1768c36626c672e5c7834f297496f568b2f (diff) |
Refactor account data (#1150)
* Refactor account data
* Tweak database fetching
* Tweaks
* Restore syncProducer notification
* Various tweaks, update tag behaviour
* Fix initial sync
Diffstat (limited to 'clientapi')
-rw-r--r-- | clientapi/routing/account_data.go | 55 | ||||
-rw-r--r-- | clientapi/routing/room_tagging.go | 125 | ||||
-rw-r--r-- | clientapi/routing/routing.go | 14 |
3 files changed, 87 insertions, 107 deletions
diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 68e0dc5d..d5fafedb 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -16,21 +16,20 @@ package routing import ( "encoding/json" + "fmt" "io/ioutil" "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} func GetAccountData( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, ) util.JSONResponse { if userID != device.UserID { @@ -40,18 +39,28 @@ func GetAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, + } + dataRes := api.QueryAccountDataResponse{} + if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") + return util.ErrorResponse(fmt.Errorf("userAPI.QueryAccountData: %w", err)) } - if data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, dataType, - ); err == nil { + var data json.RawMessage + var ok bool + if roomID != "" { + data, ok = dataRes.RoomAccountData[roomID][dataType] + } else { + data, ok = dataRes.GlobalAccountData[dataType] + } + if ok { return util.JSONResponse{ Code: http.StatusOK, - JSON: data.Content, + JSON: data, } } @@ -63,7 +72,7 @@ func GetAccountData( // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} func SaveAccountData( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { if userID != device.UserID { @@ -73,12 +82,6 @@ func SaveAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - defer req.Body.Close() // nolint: errcheck if req.Body == http.NoBody { @@ -101,13 +104,19 @@ func SaveAccountData( } } - if err := accountDB.SaveAccountData( - req.Context(), localpart, roomID, dataType, string(body), - ); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed") - return jsonerror.InternalServerError() + dataReq := api.InputAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, + AccountData: json.RawMessage(body), + } + dataRes := api.InputAccountDataResponse{} + if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") + return util.ErrorResponse(err) } + // TODO: user API should do this since it's account data if err := syncProducer.SendData(userID, roomID, dataType); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index b1cfcca8..c683cc94 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -24,23 +24,14 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -// newTag creates and returns a new gomatrix.TagContent -func newTag() gomatrix.TagContent { - return gomatrix.TagContent{ - Tags: make(map[string]gomatrix.TagProperties), - } -} - // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags func GetTags( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -54,22 +45,15 @@ func GetTags( } } - _, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - if data == nil { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - return util.JSONResponse{ Code: http.StatusOK, - JSON: data.Content, + JSON: tagContent, } } @@ -78,7 +62,7 @@ func GetTags( // the tag to the "map" and saving the new "map" to the DB func PutTag( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -98,34 +82,25 @@ func PutTag( return *reqErr } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - var tagContent gomatrix.TagContent - if data != nil { - if err = json.Unmarshal(data.Content, &tagContent); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - } else { - tagContent = newTag() + if tagContent.Tags == nil { + tagContent.Tags = make(map[string]gomatrix.TagProperties) } tagContent.Tags[tag] = properties - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -138,7 +113,7 @@ func PutTag( // the "map" and then saving the new "map" in the DB func DeleteTag( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -153,28 +128,12 @@ func DeleteTag( } } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - // If there are no tags in the database, exit - if data == nil { - // Spec only defines 200 responses for this endpoint so we don't return anything else. - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - - var tagContent gomatrix.TagContent - err = json.Unmarshal(data.Content, &tagContent) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - // Check whether the tag to be deleted exists if _, ok := tagContent.Tags[tag]; ok { delete(tagContent.Tags, tag) @@ -185,18 +144,16 @@ func DeleteTag( JSON: struct{}{}, } } - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + // TODO: user API should do this since it's account data + if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -210,32 +167,46 @@ func obtainSavedTags( req *http.Request, userID string, roomID string, - accountDB accounts.Database, -) (string, *gomatrixserverlib.ClientEvent, error) { - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + userAPI api.UserInternalAPI, +) (tags gomatrix.TagContent, err error) { + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", + } + dataRes := api.QueryAccountDataResponse{} + err = userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes) if err != nil { - return "", nil, err + return } - - data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, "m.tag", - ) - - return localpart, data, err + data, ok := dataRes.RoomAccountData[roomID]["m.tag"] + if !ok { + return + } + if err = json.Unmarshal(data, &tags); err != nil { + return + } + return tags, nil } // saveTagData saves the provided tag data into the database func saveTagData( req *http.Request, - localpart string, + userID string, roomID string, - accountDB accounts.Database, + userAPI api.UserInternalAPI, Tag gomatrix.TagContent, ) error { newTagData, err := json.Marshal(Tag) if err != nil { return err } - - return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData)) + dataReq := api.InputAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", + AccountData: json.RawMessage(newTagData), + } + dataRes := api.InputAccountDataResponse{} + return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 41c7fb18..e91b07ac 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -476,7 +476,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], "", vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -486,7 +486,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -496,7 +496,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], "", vars["type"]) }), ).Methods(http.MethodGet) @@ -506,7 +506,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"]) }), ).Methods(http.MethodGet) @@ -604,7 +604,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer) + return GetTags(req, userAPI, device, vars["userId"], vars["roomId"], syncProducer) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -614,7 +614,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -624,7 +624,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodDelete, http.MethodOptions) |