diff options
author | Till <2353100+S7evinK@users.noreply.github.com> | 2022-10-17 14:48:35 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-17 14:48:35 +0200 |
commit | 07bfb791ca616bd3a4aa96691b74c96146d59d90 (patch) | |
tree | 1556cfeea59114ec94ca0906f5eb9db6b136056b /internal/transactions | |
parent | d72d4f8d5d0016a8dcbf77aba92671f3469eb630 (diff) |
Scope transactions to endpoints (#2799)
To avoid returning results from e.g. `/redact` on `/sendToDevice`
requests.
Takes the raw URL path and uses `filepath.Dir` to remove the `txnID`
(file) from it.
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
Diffstat (limited to 'internal/transactions')
-rw-r--r-- | internal/transactions/transactions.go | 16 | ||||
-rw-r--r-- | internal/transactions/transactions_test.go | 42 |
2 files changed, 44 insertions, 14 deletions
diff --git a/internal/transactions/transactions.go b/internal/transactions/transactions.go index d2eb0f27..7ff6f504 100644 --- a/internal/transactions/transactions.go +++ b/internal/transactions/transactions.go @@ -13,6 +13,8 @@ package transactions import ( + "net/url" + "path/filepath" "sync" "time" @@ -29,6 +31,7 @@ type txnsMap map[CacheKey]*util.JSONResponse type CacheKey struct { AccessToken string TxnID string + Endpoint string } // Cache represents a temporary store for response entries. @@ -57,14 +60,14 @@ func NewWithCleanupPeriod(cleanupPeriod time.Duration) *Cache { return &t } -// FetchTransaction looks up an entry for the (accessToken, txnID) tuple in Cache. +// FetchTransaction looks up an entry for the (accessToken, txnID, req.URL) tuple in Cache. // Looks in both the txnMaps. // Returns (JSON response, true) if txnID is found, else the returned bool is false. -func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, bool) { +func (t *Cache) FetchTransaction(accessToken, txnID string, u *url.URL) (*util.JSONResponse, bool) { t.RLock() defer t.RUnlock() for _, txns := range t.txnsMaps { - res, ok := txns[CacheKey{accessToken, txnID}] + res, ok := txns[CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] if ok { return res, true } @@ -72,13 +75,12 @@ func (t *Cache) FetchTransaction(accessToken, txnID string) (*util.JSONResponse, return nil, false } -// AddTransaction adds an entry for the (accessToken, txnID) tuple in Cache. +// AddTransaction adds an entry for the (accessToken, txnID, req.URL) tuple in Cache. // Adds to the front txnMap. -func (t *Cache) AddTransaction(accessToken, txnID string, res *util.JSONResponse) { +func (t *Cache) AddTransaction(accessToken, txnID string, u *url.URL, res *util.JSONResponse) { t.Lock() defer t.Unlock() - - t.txnsMaps[0][CacheKey{accessToken, txnID}] = res + t.txnsMaps[0][CacheKey{accessToken, txnID, filepath.Dir(u.Path)}] = res } // cacheCleanService is responsible for cleaning up entries after cleanupPeriod. diff --git a/internal/transactions/transactions_test.go b/internal/transactions/transactions_test.go index aa837f76..c552550a 100644 --- a/internal/transactions/transactions_test.go +++ b/internal/transactions/transactions_test.go @@ -14,6 +14,9 @@ package transactions import ( "net/http" + "net/url" + "path/filepath" + "reflect" "strconv" "testing" @@ -24,6 +27,16 @@ type fakeType struct { ID string `json:"ID"` } +func TestCompare(t *testing.T) { + u1, _ := url.Parse("/send/1?accessToken=123") + u2, _ := url.Parse("/send/1") + c1 := CacheKey{"1", "2", filepath.Dir(u1.Path)} + c2 := CacheKey{"1", "2", filepath.Dir(u2.Path)} + if !reflect.DeepEqual(c1, c2) { + t.Fatalf("Cache keys differ: %+v <> %+v", c1, c2) + } +} + var ( fakeAccessToken = "aRandomAccessToken" fakeAccessToken2 = "anotherRandomAccessToken" @@ -34,23 +47,28 @@ var ( fakeResponse2 = &util.JSONResponse{ Code: http.StatusOK, JSON: fakeType{ID: "1"}, } + fakeResponse3 = &util.JSONResponse{ + Code: http.StatusOK, JSON: fakeType{ID: "2"}, + } ) // TestCache creates a New Cache and tests AddTransaction & FetchTransaction func TestCache(t *testing.T) { fakeTxnCache := New() - fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) + u, _ := url.Parse("") + fakeTxnCache.AddTransaction(fakeAccessToken, fakeTxnID, u, fakeResponse) // Add entries for noise. for i := 1; i <= 100; i++ { fakeTxnCache.AddTransaction( fakeAccessToken, fakeTxnID+strconv.Itoa(i), + u, &util.JSONResponse{Code: http.StatusOK, JSON: fakeType{ID: strconv.Itoa(i)}}, ) } - testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID) + testResponse, ok := fakeTxnCache.FetchTransaction(fakeAccessToken, fakeTxnID, u) if !ok { t.Error("Failed to retrieve entry for txnID: ", fakeTxnID) } else if testResponse.JSON != fakeResponse.JSON { @@ -59,20 +77,30 @@ func TestCache(t *testing.T) { } // TestCacheScope ensures transactions with the same transaction ID are not shared -// across multiple access tokens. +// across multiple access tokens and endpoints. func TestCacheScope(t *testing.T) { cache := New() - cache.AddTransaction(fakeAccessToken, fakeTxnID, fakeResponse) - cache.AddTransaction(fakeAccessToken2, fakeTxnID, fakeResponse2) + sendEndpoint, _ := url.Parse("/send/1?accessToken=test") + sendToDeviceEndpoint, _ := url.Parse("/sendToDevice/1") + cache.AddTransaction(fakeAccessToken, fakeTxnID, sendEndpoint, fakeResponse) + cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint, fakeResponse2) + cache.AddTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint, fakeResponse3) - if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID); !ok { + if res, ok := cache.FetchTransaction(fakeAccessToken, fakeTxnID, sendEndpoint); !ok { t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) } else if res.JSON != fakeResponse.JSON { t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse.JSON, res.JSON) } - if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID); !ok { + if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendEndpoint); !ok { t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) } else if res.JSON != fakeResponse2.JSON { t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) } + + // Ensure the txnID is not shared across endpoints + if res, ok := cache.FetchTransaction(fakeAccessToken2, fakeTxnID, sendToDeviceEndpoint); !ok { + t.Errorf("failed to retrieve entry for (%s, %s)", fakeAccessToken, fakeTxnID) + } else if res.JSON != fakeResponse3.JSON { + t.Errorf("Wrong cache entry for (%s, %s). Expected: %v; got: %v", fakeAccessToken, fakeTxnID, fakeResponse2.JSON, res.JSON) + } } |