aboutsummaryrefslogtreecommitdiff
path: root/clientapi
diff options
context:
space:
mode:
authorThibaut CHARLES <cromfr@gmail.com>2019-07-24 18:08:51 +0200
committerAlex Chen <Cnly@users.noreply.github.com>2019-07-25 00:08:51 +0800
commitb729a10366f9cb6f8b34db58c7bc1b9b69e67b5f (patch)
treeb79a87f749690facebb458d1daf550ed06285795 /clientapi
parent6773572907a7748ce7f4ccd5467ee2e1d5d06f77 (diff)
Store & retrieve filters as structs rather than []byte (#436)
Manipulate filters as gomatrix.Filter structures, instead of their []byte JSON representation. This lays ground work for using filters in dendrite for /sync requests.
Diffstat (limited to 'clientapi')
-rw-r--r--clientapi/auth/storage/accounts/filter_table.go38
-rw-r--r--clientapi/auth/storage/accounts/storage.go8
-rw-r--r--clientapi/routing/filter.go20
3 files changed, 39 insertions, 27 deletions
diff --git a/clientapi/auth/storage/accounts/filter_table.go b/clientapi/auth/storage/accounts/filter_table.go
index 81bae454..2b07ef17 100644
--- a/clientapi/auth/storage/accounts/filter_table.go
+++ b/clientapi/auth/storage/accounts/filter_table.go
@@ -17,6 +17,7 @@ package accounts
import (
"context"
"database/sql"
+ "encoding/json"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -71,25 +72,44 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
func (s *filterStatements) selectFilter(
ctx context.Context, localpart string, filterID string,
-) (filter []byte, err error) {
- err = s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filter)
- return
+) (*gomatrixserverlib.Filter, error) {
+ // Retrieve filter from database (stored as canonical JSON)
+ var filterData []byte
+ err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
+ if err != nil {
+ return nil, err
+ }
+
+ // Unmarshal JSON into Filter struct
+ var filter gomatrixserverlib.Filter
+ if err = json.Unmarshal(filterData, &filter); err != nil {
+ return nil, err
+ }
+ return &filter, nil
}
func (s *filterStatements) insertFilter(
- ctx context.Context, filter []byte, localpart string,
+ ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
) (filterID string, err error) {
var existingFilterID string
- // This can result in a race condition when two clients try to insert the
- // same filter and localpart at the same time, however this is not a
- // problem as both calls will result in the same filterID
- filterJSON, err := gomatrixserverlib.CanonicalJSON(filter)
+ // Serialise json
+ filterJSON, err := json.Marshal(filter)
+ if err != nil {
+ return "", err
+ }
+ // Remove whitespaces and sort JSON data
+ // needed to prevent from inserting the same filter multiple times
+ filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
if err != nil {
return "", err
}
- // Check if filter already exists in the database
+ // Check if filter already exists in the database using its localpart and content
+ //
+ // This can result in a race condition when two clients try to insert the
+ // same filter and localpart at the same time, however this is not a
+ // problem as both calls will result in the same filterID
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
localpart, filterJSON).Scan(&existingFilterID)
if err != nil && err != sql.ErrNoRows {
diff --git a/clientapi/auth/storage/accounts/storage.go b/clientapi/auth/storage/accounts/storage.go
index 27c0a176..5c8ffffe 100644
--- a/clientapi/auth/storage/accounts/storage.go
+++ b/clientapi/auth/storage/accounts/storage.go
@@ -344,11 +344,11 @@ func (d *Database) GetThreePIDsForLocalpart(
}
// GetFilter looks up the filter associated with a given local user and filter ID.
-// Returns a filter represented as a byte slice. Otherwise returns an error if
-// no such filter exists or if there was an error talking to the database.
+// Returns a filter structure. Otherwise returns an error if no such filter exists
+// or if there was an error talking to the database.
func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string,
-) ([]byte, error) {
+) (*gomatrixserverlib.Filter, error) {
return d.filter.selectFilter(ctx, localpart, filterID)
}
@@ -356,7 +356,7 @@ func (d *Database) GetFilter(
// Returns the filterID as a string. Otherwise returns an error if something
// goes wrong.
func (d *Database) PutFilter(
- ctx context.Context, localpart string, filter []byte,
+ ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
) (string, error) {
return d.filter.insertFilter(ctx, filter, localpart)
}
diff --git a/clientapi/routing/filter.go b/clientapi/routing/filter.go
index 291a165b..eec501ff 100644
--- a/clientapi/routing/filter.go
+++ b/clientapi/routing/filter.go
@@ -17,13 +17,10 @@ package routing
import (
"net/http"
- "encoding/json"
-
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
- "github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
@@ -43,7 +40,7 @@ func GetFilter(
return httputil.LogThenError(req, err)
}
- res, err := accountDB.GetFilter(req.Context(), localpart, filterID)
+ filter, err := accountDB.GetFilter(req.Context(), localpart, filterID)
if err != nil {
//TODO better error handling. This error message is *probably* right,
// but if there are obscure db errors, this will also be returned,
@@ -53,11 +50,6 @@ func GetFilter(
JSON: jsonerror.NotFound("No such filter"),
}
}
- filter := gomatrix.Filter{}
- err = json.Unmarshal(res, &filter)
- if err != nil {
- return httputil.LogThenError(req, err)
- }
return util.JSONResponse{
Code: http.StatusOK,
@@ -85,21 +77,21 @@ func PutFilter(
return httputil.LogThenError(req, err)
}
- var filter gomatrix.Filter
+ var filter gomatrixserverlib.Filter
if reqErr := httputil.UnmarshalJSONRequest(req, &filter); reqErr != nil {
return *reqErr
}
- filterArray, err := json.Marshal(filter)
- if err != nil {
+ // Validate generates a user-friendly error
+ if err = filter.Validate(); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
- JSON: jsonerror.BadJSON("Filter is malformed"),
+ JSON: jsonerror.BadJSON("Invalid filter: " + err.Error()),
}
}
- filterID, err := accountDB.PutFilter(req.Context(), localpart, filterArray)
+ filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter)
if err != nil {
return httputil.LogThenError(req, err)
}