aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsanthoshivan23 <47689668+santhoshivan23@users.noreply.github.com>2023-06-22 22:07:21 +0530
committerGitHub <noreply@github.com>2023-06-22 16:37:21 +0000
commit45082d4dcefadceada1b4374f3876365887cfd4a (patch)
tree899fda990a71f16eb05073098196a2a1a1218bd3
parenta734b112c6577a23b87c6b54c50fb2e9a629cf2b (diff)
feat: admin APIs for token authenticated registration (#3101)
### Pull Request Checklist <!-- Please read https://matrix-org.github.io/dendrite/development/contributing before submitting your pull request --> * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Santhoshivan Amudhan santhoshivan23@gmail.com`
-rw-r--r--clientapi/admin_test.go638
-rw-r--r--clientapi/api/api.go8
-rw-r--r--clientapi/routing/admin.go242
-rw-r--r--clientapi/routing/routing.go30
-rw-r--r--setup/config/config_clientapi.go5
-rw-r--r--userapi/api/api.go6
-rw-r--r--userapi/internal/user_api.go32
-rw-r--r--userapi/storage/interface.go11
-rw-r--r--userapi/storage/postgres/registration_tokens_table.go222
-rw-r--r--userapi/storage/postgres/storage.go5
-rw-r--r--userapi/storage/shared/storage.go38
-rw-r--r--userapi/storage/sqlite3/registration_tokens_table.go222
-rw-r--r--userapi/storage/sqlite3/storage.go6
-rw-r--r--userapi/storage/tables/interface.go10
14 files changed, 1474 insertions, 1 deletions
diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go
index 1145cb12..9d2acd68 100644
--- a/clientapi/admin_test.go
+++ b/clientapi/admin_test.go
@@ -2,6 +2,7 @@ package clientapi
import (
"context"
+ "fmt"
"net/http"
"net/http/httptest"
"reflect"
@@ -23,12 +24,649 @@ import (
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
+ capi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
uapi "github.com/matrix-org/dendrite/userapi/api"
)
+func TestAdminCreateToken(t *testing.T) {
+ aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
+ bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ cfg.ClientAPI.RegistrationRequiresToken = true
+ defer close()
+ natsInstance := jetstream.NATSInstance{}
+ routers := httputil.NewRouters()
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
+ AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
+ accessTokens := map[*test.User]userDevice{
+ aliceAdmin: {},
+ bob: {},
+ }
+ createAccessTokens(t, accessTokens, userAPI, ctx, routers)
+ testCases := []struct {
+ name string
+ requestingUser *test.User
+ requestOpt test.HTTPRequestOpt
+ wantOK bool
+ withHeader bool
+ }{
+ {
+ name: "Missing auth",
+ requestingUser: bob,
+ wantOK: false,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "token": "token1",
+ },
+ ),
+ },
+ {
+ name: "Bob is denied access",
+ requestingUser: bob,
+ wantOK: false,
+ withHeader: true,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "token": "token2",
+ },
+ ),
+ },
+ {
+ name: "Alice can create a token without specifyiing any information",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{}),
+ },
+ {
+ name: "Alice can to create a token specifying a name",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "token": "token3",
+ },
+ ),
+ },
+ {
+ name: "Alice cannot to create a token that already exists",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "token": "token3",
+ },
+ ),
+ },
+ {
+ name: "Alice can create a token specifying valid params",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "token": "token4",
+ "uses_allowed": 5,
+ "expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond),
+ },
+ ),
+ },
+ {
+ name: "Alice cannot create a token specifying invalid name",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "token": "token@",
+ },
+ ),
+ },
+ {
+ name: "Alice cannot create a token specifying invalid uses_allowed",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "token": "token5",
+ "uses_allowed": -1,
+ },
+ ),
+ },
+ {
+ name: "Alice cannot create a token specifying invalid expiry_time",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "token": "token6",
+ "expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond),
+ },
+ ),
+ },
+ {
+ name: "Alice cannot to create a token specifying invalid length",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "length": 80,
+ },
+ ),
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new")
+ if tc.requestOpt != nil {
+ req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/registrationTokens/new", tc.requestOpt)
+ }
+ if tc.withHeader {
+ req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
+ }
+ rec := httptest.NewRecorder()
+ routers.DendriteAdmin.ServeHTTP(rec, req)
+ t.Logf("%s", rec.Body.String())
+ if tc.wantOK && rec.Code != http.StatusOK {
+ t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
+ }
+ })
+ }
+ })
+}
+
+func TestAdminListRegistrationTokens(t *testing.T) {
+ aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
+ bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ cfg.ClientAPI.RegistrationRequiresToken = true
+ defer close()
+ natsInstance := jetstream.NATSInstance{}
+ routers := httputil.NewRouters()
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
+ AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
+ accessTokens := map[*test.User]userDevice{
+ aliceAdmin: {},
+ bob: {},
+ }
+ tokens := []capi.RegistrationToken{
+ {
+ Token: getPointer("valid"),
+ UsesAllowed: getPointer(int32(10)),
+ ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
+ Pending: getPointer(int32(0)),
+ Completed: getPointer(int32(0)),
+ },
+ {
+ Token: getPointer("invalid"),
+ UsesAllowed: getPointer(int32(10)),
+ ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
+ Pending: getPointer(int32(0)),
+ Completed: getPointer(int32(0)),
+ },
+ }
+ for _, tkn := range tokens {
+ tkn := tkn
+ userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
+ }
+ createAccessTokens(t, accessTokens, userAPI, ctx, routers)
+ testCases := []struct {
+ name string
+ requestingUser *test.User
+ valid string
+ isValidSpecified bool
+ wantOK bool
+ withHeader bool
+ }{
+ {
+ name: "Missing auth",
+ requestingUser: bob,
+ wantOK: false,
+ isValidSpecified: false,
+ },
+ {
+ name: "Bob is denied access",
+ requestingUser: bob,
+ wantOK: false,
+ withHeader: true,
+ isValidSpecified: false,
+ },
+ {
+ name: "Alice can list all tokens",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ },
+ {
+ name: "Alice can list all valid tokens",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ valid: "true",
+ isValidSpecified: true,
+ },
+ {
+ name: "Alice can list all invalid tokens",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ valid: "false",
+ isValidSpecified: true,
+ },
+ {
+ name: "No response when valid has a bad value",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ valid: "trueee",
+ isValidSpecified: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ var path string
+ if tc.isValidSpecified {
+ path = fmt.Sprintf("/_dendrite/admin/registrationTokens?valid=%v", tc.valid)
+ } else {
+ path = "/_dendrite/admin/registrationTokens"
+ }
+ req := test.NewRequest(t, http.MethodGet, path)
+ if tc.withHeader {
+ req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
+ }
+ rec := httptest.NewRecorder()
+ routers.DendriteAdmin.ServeHTTP(rec, req)
+ t.Logf("%s", rec.Body.String())
+ if tc.wantOK && rec.Code != http.StatusOK {
+ t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
+ }
+ })
+ }
+ })
+}
+
+func TestAdminGetRegistrationToken(t *testing.T) {
+ aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
+ bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ cfg.ClientAPI.RegistrationRequiresToken = true
+ defer close()
+ natsInstance := jetstream.NATSInstance{}
+ routers := httputil.NewRouters()
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
+ AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
+ accessTokens := map[*test.User]userDevice{
+ aliceAdmin: {},
+ bob: {},
+ }
+ tokens := []capi.RegistrationToken{
+ {
+ Token: getPointer("alice_token1"),
+ UsesAllowed: getPointer(int32(10)),
+ ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
+ Pending: getPointer(int32(0)),
+ Completed: getPointer(int32(0)),
+ },
+ {
+ Token: getPointer("alice_token2"),
+ UsesAllowed: getPointer(int32(10)),
+ ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
+ Pending: getPointer(int32(0)),
+ Completed: getPointer(int32(0)),
+ },
+ }
+ for _, tkn := range tokens {
+ tkn := tkn
+ userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
+ }
+ createAccessTokens(t, accessTokens, userAPI, ctx, routers)
+ testCases := []struct {
+ name string
+ requestingUser *test.User
+ token string
+ wantOK bool
+ withHeader bool
+ }{
+ {
+ name: "Missing auth",
+ requestingUser: bob,
+ wantOK: false,
+ },
+ {
+ name: "Bob is denied access",
+ requestingUser: bob,
+ wantOK: false,
+ withHeader: true,
+ },
+ {
+ name: "Alice can GET alice_token1",
+ token: "alice_token1",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ },
+ {
+ name: "Alice can GET alice_token2",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ token: "alice_token2",
+ },
+ {
+ name: "Alice cannot GET a token that does not exists",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ token: "alice_token3",
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token)
+ req := test.NewRequest(t, http.MethodGet, path)
+ if tc.withHeader {
+ req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
+ }
+ rec := httptest.NewRecorder()
+ routers.DendriteAdmin.ServeHTTP(rec, req)
+ t.Logf("%s", rec.Body.String())
+ if tc.wantOK && rec.Code != http.StatusOK {
+ t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
+ }
+ })
+ }
+ })
+}
+
+func TestAdminDeleteRegistrationToken(t *testing.T) {
+ aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
+ bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ cfg.ClientAPI.RegistrationRequiresToken = true
+ defer close()
+ natsInstance := jetstream.NATSInstance{}
+ routers := httputil.NewRouters()
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
+ AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
+ accessTokens := map[*test.User]userDevice{
+ aliceAdmin: {},
+ bob: {},
+ }
+ tokens := []capi.RegistrationToken{
+ {
+ Token: getPointer("alice_token1"),
+ UsesAllowed: getPointer(int32(10)),
+ ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
+ Pending: getPointer(int32(0)),
+ Completed: getPointer(int32(0)),
+ },
+ {
+ Token: getPointer("alice_token2"),
+ UsesAllowed: getPointer(int32(10)),
+ ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
+ Pending: getPointer(int32(0)),
+ Completed: getPointer(int32(0)),
+ },
+ }
+ for _, tkn := range tokens {
+ tkn := tkn
+ userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
+ }
+ createAccessTokens(t, accessTokens, userAPI, ctx, routers)
+ testCases := []struct {
+ name string
+ requestingUser *test.User
+ token string
+ wantOK bool
+ withHeader bool
+ }{
+ {
+ name: "Missing auth",
+ requestingUser: bob,
+ wantOK: false,
+ },
+ {
+ name: "Bob is denied access",
+ requestingUser: bob,
+ wantOK: false,
+ withHeader: true,
+ },
+ {
+ name: "Alice can DELETE alice_token1",
+ token: "alice_token1",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ },
+ {
+ name: "Alice can DELETE alice_token2",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ token: "alice_token2",
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token)
+ req := test.NewRequest(t, http.MethodDelete, path)
+ if tc.withHeader {
+ req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
+ }
+ rec := httptest.NewRecorder()
+ routers.DendriteAdmin.ServeHTTP(rec, req)
+ t.Logf("%s", rec.Body.String())
+ if tc.wantOK && rec.Code != http.StatusOK {
+ t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
+ }
+ })
+ }
+ })
+}
+
+func TestAdminUpdateRegistrationToken(t *testing.T) {
+ aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
+ bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ cfg.ClientAPI.RegistrationRequiresToken = true
+ defer close()
+ natsInstance := jetstream.NATSInstance{}
+ routers := httputil.NewRouters()
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+ rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
+ userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
+ AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
+ accessTokens := map[*test.User]userDevice{
+ aliceAdmin: {},
+ bob: {},
+ }
+ createAccessTokens(t, accessTokens, userAPI, ctx, routers)
+ tokens := []capi.RegistrationToken{
+ {
+ Token: getPointer("alice_token1"),
+ UsesAllowed: getPointer(int32(10)),
+ ExpiryTime: getPointer(time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
+ Pending: getPointer(int32(0)),
+ Completed: getPointer(int32(0)),
+ },
+ {
+ Token: getPointer("alice_token2"),
+ UsesAllowed: getPointer(int32(10)),
+ ExpiryTime: getPointer(time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond)),
+ Pending: getPointer(int32(0)),
+ Completed: getPointer(int32(0)),
+ },
+ }
+ for _, tkn := range tokens {
+ tkn := tkn
+ userAPI.PerformAdminCreateRegistrationToken(ctx, &tkn)
+ }
+ testCases := []struct {
+ name string
+ requestingUser *test.User
+ method string
+ token string
+ requestOpt test.HTTPRequestOpt
+ wantOK bool
+ withHeader bool
+ }{
+ {
+ name: "Missing auth",
+ requestingUser: bob,
+ wantOK: false,
+ token: "alice_token1",
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "uses_allowed": 10,
+ },
+ ),
+ },
+ {
+ name: "Bob is denied access",
+ requestingUser: bob,
+ wantOK: false,
+ withHeader: true,
+ token: "alice_token1",
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "uses_allowed": 10,
+ },
+ ),
+ },
+ {
+ name: "Alice can UPDATE a token's uses_allowed property",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ token: "alice_token1",
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "uses_allowed": 10,
+ }),
+ },
+ {
+ name: "Alice can UPDATE a token's expiry_time property",
+ requestingUser: aliceAdmin,
+ wantOK: true,
+ withHeader: true,
+ token: "alice_token2",
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "expiry_time": time.Now().Add(5*24*time.Hour).UnixNano() / int64(time.Millisecond),
+ },
+ ),
+ },
+ {
+ name: "Alice can UPDATE a token's uses_allowed and expiry_time property",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ token: "alice_token1",
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "uses_allowed": 20,
+ "expiry_time": time.Now().Add(10*24*time.Hour).UnixNano() / int64(time.Millisecond),
+ },
+ ),
+ },
+ {
+ name: "Alice CANNOT update a token with invalid properties",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ token: "alice_token2",
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "uses_allowed": -5,
+ "expiry_time": time.Now().Add(-1*5*24*time.Hour).UnixNano() / int64(time.Millisecond),
+ },
+ ),
+ },
+ {
+ name: "Alice CANNOT UPDATE a token that does not exist",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ token: "alice_token9",
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "uses_allowed": 100,
+ },
+ ),
+ },
+ {
+ name: "Alice can UPDATE token specifying uses_allowed as null - Valid for infinite uses",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ token: "alice_token1",
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "uses_allowed": nil,
+ },
+ ),
+ },
+ {
+ name: "Alice can UPDATE token specifying expiry_time AS null - Valid for infinite time",
+ requestingUser: aliceAdmin,
+ wantOK: false,
+ withHeader: true,
+ token: "alice_token1",
+ requestOpt: test.WithJSONBody(t, map[string]interface{}{
+ "expiry_time": nil,
+ },
+ ),
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ path := fmt.Sprintf("/_dendrite/admin/registrationTokens/%s", tc.token)
+ req := test.NewRequest(t, http.MethodPut, path)
+ if tc.requestOpt != nil {
+ req = test.NewRequest(t, http.MethodPut, path, tc.requestOpt)
+ }
+ if tc.withHeader {
+ req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser].accessToken)
+ }
+ rec := httptest.NewRecorder()
+ routers.DendriteAdmin.ServeHTTP(rec, req)
+ t.Logf("%s", rec.Body.String())
+ if tc.wantOK && rec.Code != http.StatusOK {
+ t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String())
+ }
+ })
+ }
+ })
+}
+
+func getPointer[T any](s T) *T {
+ return &s
+}
+
func TestAdminResetPassword(t *testing.T) {
aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin))
bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser))
diff --git a/clientapi/api/api.go b/clientapi/api/api.go
index 23974c86..28ff593f 100644
--- a/clientapi/api/api.go
+++ b/clientapi/api/api.go
@@ -21,3 +21,11 @@ type ExtraPublicRoomsProvider interface {
// Rooms returns the extra rooms. This is called on-demand by clients, so cache appropriately.
Rooms() []fclient.PublicRoom
}
+
+type RegistrationToken struct {
+ Token *string `json:"token"`
+ UsesAllowed *int32 `json:"uses_allowed"`
+ Pending *int32 `json:"pending"`
+ Completed *int32 `json:"completed"`
+ ExpiryTime *int64 `json:"expiry_time"`
+}
diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go
index 3d64454c..51966607 100644
--- a/clientapi/routing/admin.go
+++ b/clientapi/routing/admin.go
@@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"net/http"
+ "regexp"
+ "strconv"
"time"
"github.com/gorilla/mux"
@@ -16,14 +18,254 @@ import (
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
+ "golang.org/x/exp/constraints"
+ clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
)
+var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$")
+
+func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
+ if !cfg.RegistrationRequiresToken {
+ return util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"),
+ }
+ }
+ request := struct {
+ Token string `json:"token"`
+ UsesAllowed *int32 `json:"uses_allowed,omitempty"`
+ ExpiryTime *int64 `json:"expiry_time,omitempty"`
+ Length int32 `json:"length"`
+ }{}
+
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
+ }
+ }
+
+ token := request.Token
+ usesAllowed := request.UsesAllowed
+ expiryTime := request.ExpiryTime
+ length := request.Length
+
+ if len(token) == 0 {
+ if length == 0 {
+ // length not provided in request. Assign default value of 16.
+ length = 16
+ }
+ // token not present in request body. Hence, generate a random token.
+ if length <= 0 || length > 64 {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("length must be greater than zero and not greater than 64"),
+ }
+ }
+ token = util.RandomString(int(length))
+ }
+
+ if len(token) > 64 {
+ //Token present in request body, but is too long.
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("token must not be longer than 64"),
+ }
+ }
+
+ isTokenValid := validRegistrationTokenRegex.Match([]byte(token))
+ if !isTokenValid {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"),
+ }
+ }
+ // At this point, we have a valid token, either through request body or through random generation.
+ if usesAllowed != nil && *usesAllowed < 0 {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
+ }
+ }
+ if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("expiry_time must not be in the past"),
+ }
+ }
+ pending := int32(0)
+ completed := int32(0)
+ // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB)
+ registrationToken := &clientapi.RegistrationToken{
+ Token: &token,
+ UsesAllowed: usesAllowed,
+ Pending: &pending,
+ Completed: &completed,
+ ExpiryTime: expiryTime,
+ }
+ created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken)
+ if !created {
+ return util.JSONResponse{
+ Code: http.StatusConflict,
+ JSON: map[string]string{
+ "error": fmt.Sprintf("token: %s already exists", token),
+ },
+ }
+ }
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: err,
+ }
+ }
+ return util.JSONResponse{
+ Code: 200,
+ JSON: map[string]interface{}{
+ "token": token,
+ "uses_allowed": getReturnValue(usesAllowed),
+ "pending": pending,
+ "completed": completed,
+ "expiry_time": getReturnValue(expiryTime),
+ },
+ }
+}
+
+func getReturnValue[t constraints.Integer](in *t) any {
+ if in == nil {
+ return nil
+ }
+ return *in
+}
+
+func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
+ queryParams := req.URL.Query()
+ returnAll := true
+ valid := true
+ validQuery, ok := queryParams["valid"]
+ if ok {
+ returnAll = false
+ validValue, err := strconv.ParseBool(validQuery[0])
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("invalid 'valid' query parameter"),
+ }
+ }
+ valid = validValue
+ }
+ tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.ErrorUnknown,
+ }
+ }
+ return util.JSONResponse{
+ Code: 200,
+ JSON: map[string]interface{}{
+ "registration_tokens": tokens,
+ },
+ }
+}
+
+func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ tokenText := vars["token"]
+ token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusNotFound,
+ JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
+ }
+ }
+ return util.JSONResponse{
+ Code: 200,
+ JSON: token,
+ }
+}
+
+func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ tokenText := vars["token"]
+ err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: err,
+ }
+ }
+ return util.JSONResponse{
+ Code: 200,
+ JSON: map[string]interface{}{},
+ }
+}
+
+func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse {
+ vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
+ if err != nil {
+ return util.ErrorResponse(err)
+ }
+ tokenText := vars["token"]
+ request := make(map[string]*int64)
+ if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)),
+ }
+ }
+ newAttributes := make(map[string]interface{})
+ usesAllowed, ok := request["uses_allowed"]
+ if ok {
+ // Only add usesAllowed to newAtrributes if it is present and valid
+ if usesAllowed != nil && *usesAllowed < 0 {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"),
+ }
+ }
+ newAttributes["usesAllowed"] = usesAllowed
+ }
+ expiryTime, ok := request["expiry_time"]
+ if ok {
+ // Only add expiryTime to newAtrributes if it is present and valid
+ if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.BadJSON("expiry_time must not be in the past"),
+ }
+ }
+ newAttributes["expiryTime"] = expiryTime
+ }
+ if len(newAttributes) == 0 {
+ // No attributes to update. Return existing token
+ return AdminGetRegistrationToken(req, cfg, userAPI)
+ }
+ updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusNotFound,
+ JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)),
+ }
+ }
+ return util.JSONResponse{
+ Code: 200,
+ JSON: *updatedToken,
+ }
+}
+
func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index d3f19cae..ab4aefdd 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -162,6 +162,36 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
}
+ dendriteAdminRouter.Handle("/admin/registrationTokens/new",
+ httputil.MakeAdminAPI("admin_registration_tokens_new", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return AdminCreateNewRegistrationToken(req, cfg, userAPI)
+ }),
+ ).Methods(http.MethodPost, http.MethodOptions)
+
+ dendriteAdminRouter.Handle("/admin/registrationTokens",
+ httputil.MakeAdminAPI("admin_list_registration_tokens", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ return AdminListRegistrationTokens(req, cfg, userAPI)
+ }),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ dendriteAdminRouter.Handle("/admin/registrationTokens/{token}",
+ httputil.MakeAdminAPI("admin_get_registration_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ switch req.Method {
+ case http.MethodGet:
+ return AdminGetRegistrationToken(req, cfg, userAPI)
+ case http.MethodPut:
+ return AdminUpdateRegistrationToken(req, cfg, userAPI)
+ case http.MethodDelete:
+ return AdminDeleteRegistrationToken(req, cfg, userAPI)
+ default:
+ return util.MatrixErrorResponse(
+ 404,
+ string(spec.ErrorNotFound),
+ "unknown method",
+ )
+ }
+ }),
+ ).Methods(http.MethodGet, http.MethodPut, http.MethodDelete, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/evacuateRoom/{roomID}",
httputil.MakeAdminAPI("admin_evacuate_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go
index b6c74a75..44136e2a 100644
--- a/setup/config/config_clientapi.go
+++ b/setup/config/config_clientapi.go
@@ -13,6 +13,10 @@ type ClientAPI struct {
// secrets)
RegistrationDisabled bool `yaml:"registration_disabled"`
+ // If set, requires users to submit a token during registration.
+ // Tokens can be managed using admin API.
+ RegistrationRequiresToken bool `yaml:"registration_requires_token"`
+
// Enable registration without captcha verification or shared secret.
// This option is populated by the -really-enable-open-registration
// command line parameter as it is not recommended.
@@ -56,6 +60,7 @@ type ClientAPI struct {
func (c *ClientAPI) Defaults(opts DefaultOpts) {
c.RegistrationSharedSecret = ""
+ c.RegistrationRequiresToken = false
c.RecaptchaPublicKey = ""
c.RecaptchaPrivateKey = ""
c.RecaptchaEnabled = false
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 05040264..a0dce975 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -27,6 +27,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
+ clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules"
)
@@ -94,6 +95,11 @@ type ClientUserAPI interface {
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
+ PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
+ PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
+ PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
+ PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error
+ PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go
index 32f3d84b..4305c13a 100644
--- a/userapi/internal/user_api.go
+++ b/userapi/internal/user_api.go
@@ -33,6 +33,7 @@ import (
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
+ clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
@@ -63,6 +64,37 @@ type UserInternalAPI struct {
Updater *DeviceListUpdater
}
+func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) {
+ exists, err := a.DB.RegistrationTokenExists(ctx, *registrationToken.Token)
+ if err != nil {
+ return false, err
+ }
+ if exists {
+ return false, fmt.Errorf("token: %s already exists", *registrationToken.Token)
+ }
+ _, err = a.DB.InsertRegistrationToken(ctx, registrationToken)
+ if err != nil {
+ return false, fmt.Errorf("Error creating token: %s"+err.Error(), *registrationToken.Token)
+ }
+ return true, nil
+}
+
+func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
+ return a.DB.ListRegistrationTokens(ctx, returnAll, valid)
+}
+
+func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
+ return a.DB.GetRegistrationToken(ctx, tokenString)
+}
+
+func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error {
+ return a.DB.DeleteRegistrationToken(ctx, tokenString)
+}
+
+func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) {
+ return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
+}
+
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index 4f5e99a8..125b3158 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -23,6 +23,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
+ clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/userapi/api"
@@ -30,6 +31,15 @@ import (
"github.com/matrix-org/dendrite/userapi/types"
)
+type RegistrationTokens interface {
+ RegistrationTokenExists(ctx context.Context, token string) (bool, error)
+ InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
+ ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
+ GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
+ DeleteRegistrationToken(ctx context.Context, tokenString string) error
+ UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
+}
+
type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
@@ -144,6 +154,7 @@ type UserDatabase interface {
Pusher
Statistics
ThreePID
+ RegistrationTokens
}
type KeyChangeDatabase interface {
diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go
new file mode 100644
index 00000000..3c3e3fdd
--- /dev/null
+++ b/userapi/storage/postgres/registration_tokens_table.go
@@ -0,0 +1,222 @@
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/matrix-org/dendrite/clientapi/api"
+ internal "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "golang.org/x/exp/constraints"
+)
+
+const registrationTokensSchema = `
+CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
+ token TEXT PRIMARY KEY,
+ pending BIGINT,
+ completed BIGINT,
+ uses_allowed BIGINT,
+ expiry_time BIGINT
+);
+`
+
+const selectTokenSQL = "" +
+ "SELECT token FROM userapi_registration_tokens WHERE token = $1"
+
+const insertTokenSQL = "" +
+ "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
+
+const listAllTokensSQL = "" +
+ "SELECT * FROM userapi_registration_tokens"
+
+const listValidTokensSQL = "" +
+ "SELECT * FROM userapi_registration_tokens WHERE" +
+ "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
+ "(expiry_time > $1 OR expiry_time IS NULL)"
+
+const listInvalidTokensSQL = "" +
+ "SELECT * FROM userapi_registration_tokens WHERE" +
+ "(uses_allowed <= pending + completed OR expiry_time <= $1)"
+
+const getTokenSQL = "" +
+ "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"
+
+const deleteTokenSQL = "" +
+ "DELETE FROM userapi_registration_tokens WHERE token = $1"
+
+const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
+ "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
+
+const updateTokenUsesAllowedSQL = "" +
+ "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
+
+const updateTokenExpiryTimeSQL = "" +
+ "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
+
+type registrationTokenStatements struct {
+ selectTokenStatement *sql.Stmt
+ insertTokenStatement *sql.Stmt
+ listAllTokensStatement *sql.Stmt
+ listValidTokensStatement *sql.Stmt
+ listInvalidTokenStatement *sql.Stmt
+ getTokenStatement *sql.Stmt
+ deleteTokenStatement *sql.Stmt
+ updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
+ updateTokenUsesAllowedStatement *sql.Stmt
+ updateTokenExpiryTimeStatement *sql.Stmt
+}
+
+func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
+ s := &registrationTokenStatements{}
+ _, err := db.Exec(registrationTokensSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.selectTokenStatement, selectTokenSQL},
+ {&s.insertTokenStatement, insertTokenSQL},
+ {&s.listAllTokensStatement, listAllTokensSQL},
+ {&s.listValidTokensStatement, listValidTokensSQL},
+ {&s.listInvalidTokenStatement, listInvalidTokensSQL},
+ {&s.getTokenStatement, getTokenSQL},
+ {&s.deleteTokenStatement, deleteTokenSQL},
+ {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
+ {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
+ {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
+ }.Prepare(db)
+}
+
+func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
+ var existingToken string
+ stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
+ err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return false, nil
+ }
+ return false, err
+ }
+ return true, nil
+}
+
+func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
+ stmt := sqlutil.TxStmt(tx, s.insertTokenStatement)
+ _, err := stmt.ExecContext(
+ ctx,
+ *registrationToken.Token,
+ getInsertValue(registrationToken.UsesAllowed),
+ getInsertValue(registrationToken.ExpiryTime),
+ *registrationToken.Pending,
+ *registrationToken.Completed)
+ if err != nil {
+ return false, err
+ }
+ return true, nil
+}
+
+func getInsertValue[t constraints.Integer](in *t) any {
+ if in == nil {
+ return nil
+ }
+ return *in
+}
+
+func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
+ var stmt *sql.Stmt
+ var tokens []api.RegistrationToken
+ var tokenString string
+ var pending, completed, usesAllowed *int32
+ var expiryTime *int64
+ var rows *sql.Rows
+ var err error
+ if returnAll {
+ stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
+ rows, err = stmt.QueryContext(ctx)
+ } else if valid {
+ stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
+ rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
+ } else {
+ stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
+ rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
+ }
+ if err != nil {
+ return tokens, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
+ for rows.Next() {
+ err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
+ if err != nil {
+ return tokens, err
+ }
+ tokenString := tokenString
+ pending := pending
+ completed := completed
+ usesAllowed := usesAllowed
+ expiryTime := expiryTime
+
+ tokenMap := api.RegistrationToken{
+ Token: &tokenString,
+ Pending: pending,
+ Completed: completed,
+ UsesAllowed: usesAllowed,
+ ExpiryTime: expiryTime,
+ }
+ tokens = append(tokens, tokenMap)
+ }
+ return tokens, rows.Err()
+}
+
+func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
+ stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
+ var pending, completed, usesAllowed *int32
+ var expiryTime *int64
+ err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
+ if err != nil {
+ return nil, err
+ }
+ token := api.RegistrationToken{
+ Token: &tokenString,
+ Pending: pending,
+ Completed: completed,
+ UsesAllowed: usesAllowed,
+ ExpiryTime: expiryTime,
+ }
+ return &token, nil
+}
+
+func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
+ stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
+ _, err := stmt.ExecContext(ctx, tokenString)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
+ var stmt *sql.Stmt
+ usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
+ expiryTime, expiryTimePresent := newAttributes["expiryTime"]
+ if usesAllowedPresent && expiryTimePresent {
+ stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
+ _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
+ if err != nil {
+ return nil, err
+ }
+ } else if usesAllowedPresent {
+ stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
+ _, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
+ if err != nil {
+ return nil, err
+ }
+ } else if expiryTimePresent {
+ stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
+ _, err := stmt.ExecContext(ctx, tokenString, expiryTime)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return s.GetRegistrationToken(ctx, tx, tokenString)
+}
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index 72e7c9cd..d01ccc77 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
return nil, err
}
+ registationTokensTable, err := NewPostgresRegistrationTokensTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err)
+ }
accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
@@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
ThreePIDs: threePIDTable,
Pushers: pusherTable,
Notifications: notificationsTable,
+ RegistrationTokens: registationTokensTable,
Stats: statsTable,
ServerName: serverName,
DB: db,
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index 537bbbf4..b7acb203 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -31,6 +31,7 @@ import (
"github.com/matrix-org/gomatrixserverlib/spec"
"golang.org/x/crypto/bcrypt"
+ clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
@@ -43,6 +44,7 @@ import (
type Database struct {
DB *sql.DB
Writer sqlutil.Writer
+ RegistrationTokens tables.RegistrationTokensTable
Accounts tables.AccountsTable
Profiles tables.ProfileTable
AccountDatas tables.AccountDataTable
@@ -78,6 +80,42 @@ const (
loginTokenByteLength = 32
)
+func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) {
+ return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token)
+}
+
+func (d *Database) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (created bool, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, registrationToken)
+ return err
+ })
+ return
+}
+
+func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
+ return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid)
+}
+
+func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
+ return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString)
+}
+
+func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
+ return err
+ })
+ return
+}
+
+func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes)
+ return err
+ })
+ return
+}
+
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
diff --git a/userapi/storage/sqlite3/registration_tokens_table.go b/userapi/storage/sqlite3/registration_tokens_table.go
new file mode 100644
index 00000000..89795473
--- /dev/null
+++ b/userapi/storage/sqlite3/registration_tokens_table.go
@@ -0,0 +1,222 @@
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/matrix-org/dendrite/clientapi/api"
+ internal "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "golang.org/x/exp/constraints"
+)
+
+const registrationTokensSchema = `
+CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
+ token TEXT PRIMARY KEY,
+ pending BIGINT,
+ completed BIGINT,
+ uses_allowed BIGINT,
+ expiry_time BIGINT
+);
+`
+
+const selectTokenSQL = "" +
+ "SELECT token FROM userapi_registration_tokens WHERE token = $1"
+
+const insertTokenSQL = "" +
+ "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
+
+const listAllTokensSQL = "" +
+ "SELECT * FROM userapi_registration_tokens"
+
+const listValidTokensSQL = "" +
+ "SELECT * FROM userapi_registration_tokens WHERE" +
+ "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
+ "(expiry_time > $1 OR expiry_time IS NULL)"
+
+const listInvalidTokensSQL = "" +
+ "SELECT * FROM userapi_registration_tokens WHERE" +
+ "(uses_allowed <= pending + completed OR expiry_time <= $1)"
+
+const getTokenSQL = "" +
+ "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"
+
+const deleteTokenSQL = "" +
+ "DELETE FROM userapi_registration_tokens WHERE token = $1"
+
+const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
+ "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
+
+const updateTokenUsesAllowedSQL = "" +
+ "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
+
+const updateTokenExpiryTimeSQL = "" +
+ "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
+
+type registrationTokenStatements struct {
+ selectTokenStatement *sql.Stmt
+ insertTokenStatement *sql.Stmt
+ listAllTokensStatement *sql.Stmt
+ listValidTokensStatement *sql.Stmt
+ listInvalidTokenStatement *sql.Stmt
+ getTokenStatement *sql.Stmt
+ deleteTokenStatement *sql.Stmt
+ updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
+ updateTokenUsesAllowedStatement *sql.Stmt
+ updateTokenExpiryTimeStatement *sql.Stmt
+}
+
+func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
+ s := &registrationTokenStatements{}
+ _, err := db.Exec(registrationTokensSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.selectTokenStatement, selectTokenSQL},
+ {&s.insertTokenStatement, insertTokenSQL},
+ {&s.listAllTokensStatement, listAllTokensSQL},
+ {&s.listValidTokensStatement, listValidTokensSQL},
+ {&s.listInvalidTokenStatement, listInvalidTokensSQL},
+ {&s.getTokenStatement, getTokenSQL},
+ {&s.deleteTokenStatement, deleteTokenSQL},
+ {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
+ {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
+ {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
+ }.Prepare(db)
+}
+
+func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
+ var existingToken string
+ stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
+ err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return false, nil
+ }
+ return false, err
+ }
+ return true, nil
+}
+
+func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
+ stmt := sqlutil.TxStmt(tx, s.insertTokenStatement)
+ _, err := stmt.ExecContext(
+ ctx,
+ *registrationToken.Token,
+ getInsertValue(registrationToken.UsesAllowed),
+ getInsertValue(registrationToken.ExpiryTime),
+ *registrationToken.Pending,
+ *registrationToken.Completed)
+ if err != nil {
+ return false, err
+ }
+ return true, nil
+}
+
+func getInsertValue[t constraints.Integer](in *t) any {
+ if in == nil {
+ return nil
+ }
+ return *in
+}
+
+func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
+ var stmt *sql.Stmt
+ var tokens []api.RegistrationToken
+ var tokenString string
+ var pending, completed, usesAllowed *int32
+ var expiryTime *int64
+ var rows *sql.Rows
+ var err error
+ if returnAll {
+ stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
+ rows, err = stmt.QueryContext(ctx)
+ } else if valid {
+ stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
+ rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
+ } else {
+ stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
+ rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
+ }
+ if err != nil {
+ return tokens, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
+ for rows.Next() {
+ err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
+ if err != nil {
+ return tokens, err
+ }
+ tokenString := tokenString
+ pending := pending
+ completed := completed
+ usesAllowed := usesAllowed
+ expiryTime := expiryTime
+
+ tokenMap := api.RegistrationToken{
+ Token: &tokenString,
+ Pending: pending,
+ Completed: completed,
+ UsesAllowed: usesAllowed,
+ ExpiryTime: expiryTime,
+ }
+ tokens = append(tokens, tokenMap)
+ }
+ return tokens, rows.Err()
+}
+
+func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
+ stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
+ var pending, completed, usesAllowed *int32
+ var expiryTime *int64
+ err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
+ if err != nil {
+ return nil, err
+ }
+ token := api.RegistrationToken{
+ Token: &tokenString,
+ Pending: pending,
+ Completed: completed,
+ UsesAllowed: usesAllowed,
+ ExpiryTime: expiryTime,
+ }
+ return &token, nil
+}
+
+func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
+ stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
+ _, err := stmt.ExecContext(ctx, tokenString)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
+ var stmt *sql.Stmt
+ usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
+ expiryTime, expiryTimePresent := newAttributes["expiryTime"]
+ if usesAllowedPresent && expiryTimePresent {
+ stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
+ _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
+ if err != nil {
+ return nil, err
+ }
+ } else if usesAllowedPresent {
+ stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
+ _, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
+ if err != nil {
+ return nil, err
+ }
+ } else if expiryTimePresent {
+ stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
+ _, err := stmt.ExecContext(ctx, tokenString, expiryTime)
+ if err != nil {
+ return nil, err
+ }
+ }
+ return s.GetRegistrationToken(ctx, tx, tokenString)
+}
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index acd9678f..48f5c842 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -50,7 +50,10 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti
if err = m.Up(ctx); err != nil {
return nil, err
}
-
+ registationTokensTable, err := NewSQLiteRegistrationTokensTable(db)
+ if err != nil {
+ return nil, fmt.Errorf("NewSQLiteRegistrationsTokenTable: %w", err)
+ }
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
@@ -130,6 +133,7 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti
LoginTokenLifetime: loginTokenLifetime,
BcryptCost: bcryptCost,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
+ RegistrationTokens: registationTokensTable,
}, nil
}
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 3c6214e7..3a0be73e 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -25,10 +25,20 @@ import (
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
+ clientapi "github.com/matrix-org/dendrite/clientapi/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/types"
)
+type RegistrationTokensTable interface {
+ RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error)
+ InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error)
+ ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
+ GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error)
+ DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error
+ UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
+}
+
type AccountDataTable interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)