diff options
Diffstat (limited to 'clientapi/routing/admin.go')
-rw-r--r-- | clientapi/routing/admin.go | 242 |
1 files changed, 242 insertions, 0 deletions
diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 3d64454c..51966607 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "net/http" + "regexp" + "strconv" "time" "github.com/gorilla/mux" @@ -16,14 +18,254 @@ import ( "github.com/matrix-org/util" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" + "golang.org/x/exp/constraints" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/internal/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" ) +var validRegistrationTokenRegex = regexp.MustCompile("^[[:ascii:][:digit:]_]*$") + +func AdminCreateNewRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + if !cfg.RegistrationRequiresToken { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("Registration via tokens is not enabled on this homeserver"), + } + } + request := struct { + Token string `json:"token"` + UsesAllowed *int32 `json:"uses_allowed,omitempty"` + ExpiryTime *int64 `json:"expiry_time,omitempty"` + Length int32 `json:"length"` + }{} + + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), + } + } + + token := request.Token + usesAllowed := request.UsesAllowed + expiryTime := request.ExpiryTime + length := request.Length + + if len(token) == 0 { + if length == 0 { + // length not provided in request. Assign default value of 16. + length = 16 + } + // token not present in request body. Hence, generate a random token. + if length <= 0 || length > 64 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("length must be greater than zero and not greater than 64"), + } + } + token = util.RandomString(int(length)) + } + + if len(token) > 64 { + //Token present in request body, but is too long. + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must not be longer than 64"), + } + } + + isTokenValid := validRegistrationTokenRegex.Match([]byte(token)) + if !isTokenValid { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("token must consist only of characters matched by the regex [A-Za-z0-9-_]"), + } + } + // At this point, we have a valid token, either through request body or through random generation. + if usesAllowed != nil && *usesAllowed < 0 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } + } + if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } + } + pending := int32(0) + completed := int32(0) + // If usesAllowed or expiryTime is 0, it means they are not present in the request. NULL (indicating unlimited uses / no expiration will be persisted in DB) + registrationToken := &clientapi.RegistrationToken{ + Token: &token, + UsesAllowed: usesAllowed, + Pending: &pending, + Completed: &completed, + ExpiryTime: expiryTime, + } + created, err := userAPI.PerformAdminCreateRegistrationToken(req.Context(), registrationToken) + if !created { + return util.JSONResponse{ + Code: http.StatusConflict, + JSON: map[string]string{ + "error": fmt.Sprintf("token: %s already exists", token), + }, + } + } + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "token": token, + "uses_allowed": getReturnValue(usesAllowed), + "pending": pending, + "completed": completed, + "expiry_time": getReturnValue(expiryTime), + }, + } +} + +func getReturnValue[t constraints.Integer](in *t) any { + if in == nil { + return nil + } + return *in +} + +func AdminListRegistrationTokens(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + queryParams := req.URL.Query() + returnAll := true + valid := true + validQuery, ok := queryParams["valid"] + if ok { + returnAll = false + validValue, err := strconv.ParseBool(validQuery[0]) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("invalid 'valid' query parameter"), + } + } + valid = validValue + } + tokens, err := userAPI.PerformAdminListRegistrationTokens(req.Context(), returnAll, valid) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.ErrorUnknown, + } + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{ + "registration_tokens": tokens, + }, + } +} + +func AdminGetRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + token, err := userAPI.PerformAdminGetRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } + } + return util.JSONResponse{ + Code: 200, + JSON: token, + } +} + +func AdminDeleteRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + err = userAPI.PerformAdminDeleteRegistrationToken(req.Context(), tokenText) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{}, + } +} + +func AdminUpdateRegistrationToken(req *http.Request, cfg *config.ClientAPI, userAPI userapi.ClientUserAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + tokenText := vars["token"] + request := make(map[string]*int64) + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON(fmt.Sprintf("Failed to decode request body: %s", err)), + } + } + newAttributes := make(map[string]interface{}) + usesAllowed, ok := request["uses_allowed"] + if ok { + // Only add usesAllowed to newAtrributes if it is present and valid + if usesAllowed != nil && *usesAllowed < 0 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("uses_allowed must be a non-negative integer or null"), + } + } + newAttributes["usesAllowed"] = usesAllowed + } + expiryTime, ok := request["expiry_time"] + if ok { + // Only add expiryTime to newAtrributes if it is present and valid + if expiryTime != nil && spec.Timestamp(*expiryTime).Time().Before(time.Now()) { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: spec.BadJSON("expiry_time must not be in the past"), + } + } + newAttributes["expiryTime"] = expiryTime + } + if len(newAttributes) == 0 { + // No attributes to update. Return existing token + return AdminGetRegistrationToken(req, cfg, userAPI) + } + updatedToken, err := userAPI.PerformAdminUpdateRegistrationToken(req.Context(), tokenText, newAttributes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound(fmt.Sprintf("token: %s not found", tokenText)), + } + } + return util.JSONResponse{ + Code: 200, + JSON: *updatedToken, + } +} + func AdminEvacuateRoom(req *http.Request, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { |