aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--clientapi/routing/routing.go3
-rw-r--r--federationapi/routing/routing.go48
-rw-r--r--go.mod2
-rw-r--r--go.sum4
-rw-r--r--internal/httputil/httpapi.go32
-rw-r--r--internal/sqlutil/sqlutil_test.go2
-rw-r--r--mediaapi/mediaapi.go9
-rw-r--r--mediaapi/routing/download.go164
-rw-r--r--mediaapi/routing/download_test.go30
-rw-r--r--mediaapi/routing/routing.go105
-rw-r--r--roomserver/acls/acls_test.go6
-rw-r--r--setup/monolith.go2
-rw-r--r--userapi/internal/key_api.go2
13 files changed, 364 insertions, 45 deletions
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index 60dad543..e82c8861 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -94,6 +94,7 @@ func Setup(
unstableFeatures := map[string]bool{
"org.matrix.e2e_cross_signing": true,
"org.matrix.msc2285.stable": true,
+ "org.matrix.msc3916.stable": true,
}
for _, msc := range cfg.MSCs.MSCs {
unstableFeatures["org.matrix."+msc] = true
@@ -732,7 +733,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
v3mux.Handle("/auth/{authType}/fallback/web",
- httputil.MakeHTMLAPI("auth_fallback", enableMetrics, func(w http.ResponseWriter, req *http.Request) {
+ httputil.MakeHTTPAPI("auth_fallback", userAPI, enableMetrics, func(w http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)
AuthFallback(w, req, vars["authType"], cfg)
}),
diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go
index 6328d165..91718efd 100644
--- a/federationapi/routing/routing.go
+++ b/federationapi/routing/routing.go
@@ -16,6 +16,7 @@ package routing
import (
"context"
+ "encoding/json"
"fmt"
"net/http"
"sync"
@@ -678,6 +679,53 @@ func MakeFedAPI(
return httputil.MakeExternalAPI(metricsName, h)
}
+// MakeFedHTTPAPI makes an http.Handler that checks matrix federation authentication.
+func MakeFedHTTPAPI(
+ serverName spec.ServerName,
+ isLocalServerName func(spec.ServerName) bool,
+ keyRing gomatrixserverlib.JSONVerifier,
+ f func(http.ResponseWriter, *http.Request),
+) http.Handler {
+ h := func(w http.ResponseWriter, req *http.Request) {
+ fedReq, errResp := fclient.VerifyHTTPRequest(
+ req, time.Now(), serverName, isLocalServerName, keyRing,
+ )
+
+ enc := json.NewEncoder(w)
+ logger := util.GetLogger(req.Context())
+ if fedReq == nil {
+
+ logger.Debugf("VerifyUserFromRequest %s -> HTTP %d", req.RemoteAddr, errResp.Code)
+ w.WriteHeader(errResp.Code)
+ if err := enc.Encode(errResp); err != nil {
+ logger.WithError(err).Error("failed to encode JSON response")
+ }
+ return
+ }
+ // add the user to Sentry, if enabled
+ hub := sentry.GetHubFromContext(req.Context())
+ if hub != nil {
+ // clone the hub, so we don't send garbage events with e.g. mismatching rooms/event_ids
+ hub = hub.Clone()
+ hub.Scope().SetTag("origin", string(fedReq.Origin()))
+ hub.Scope().SetTag("uri", fedReq.RequestURI())
+ }
+ defer func() {
+ if r := recover(); r != nil {
+ if hub != nil {
+ hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path))
+ }
+ // re-panic to return the 500
+ panic(r)
+ }
+ }()
+
+ f(w, req)
+ }
+
+ return http.HandlerFunc(h)
+}
+
type FederationWakeups struct {
FsAPI *fedInternal.FederationInternalAPI
origins sync.Map
diff --git a/go.mod b/go.mod
index 6d04470b..a7e4d471 100644
--- a/go.mod
+++ b/go.mod
@@ -21,7 +21,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
- github.com/matrix-org/gomatrixserverlib v0.0.0-20240328203753-c2391f7113a5
+ github.com/matrix-org/gomatrixserverlib v0.0.0-20240801173829-d531860ad2cb
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.22
diff --git a/go.sum b/go.sum
index 44257993..0012386f 100644
--- a/go.sum
+++ b/go.sum
@@ -210,8 +210,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20240328203753-c2391f7113a5 h1:GuxmpyjZQoqb6UFQgKq8Td3wIITlXln/sItqp1jbTTA=
-github.com/matrix-org/gomatrixserverlib v0.0.0-20240328203753-c2391f7113a5/go.mod h1:HZGsVJ3bUE+DkZtufkH9H0mlsvbhEGK5CpX0Zlavylg=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20240801173829-d531860ad2cb h1:vb9RyAU+5r5jGTIjlteq8XK71X6Q+fqnmh8gSUUuLrI=
+github.com/matrix-org/gomatrixserverlib v0.0.0-20240801173829-d531860ad2cb/go.mod h1:HZGsVJ3bUE+DkZtufkH9H0mlsvbhEGK5CpX0Zlavylg=
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7 h1:6t8kJr8i1/1I5nNttw6nn1ryQJgzVlBmSGgPiiaTdw4=
github.com/matrix-org/pinecone v0.11.1-0.20230810010612-ea4c33717fd7/go.mod h1:ReWMS/LoVnOiRAdq9sNUC2NZnd1mZkMNB52QhpTRWjg=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go
index c78aadf8..0559fbb7 100644
--- a/internal/httputil/httpapi.go
+++ b/internal/httputil/httpapi.go
@@ -15,6 +15,7 @@
package httputil
import (
+ "encoding/json"
"fmt"
"io"
"net/http"
@@ -44,6 +45,7 @@ type BasicAuth struct {
type AuthAPIOpts struct {
GuestAccessAllowed bool
+ WithAuth bool
}
// AuthAPIOption is an option to MakeAuthAPI to add additional checks (e.g. guest access) to verify
@@ -57,6 +59,13 @@ func WithAllowGuests() AuthAPIOption {
}
}
+// WithAuth is an option to MakeHTTPAPI to add authentication.
+func WithAuth() AuthAPIOption {
+ return func(opts *AuthAPIOpts) {
+ opts.WithAuth = true
+ }
+}
+
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
func MakeAuthAPI(
metricsName string, userAPI userapi.QueryAcccessTokenAPI,
@@ -197,13 +206,32 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
return http.HandlerFunc(withSpan)
}
-// MakeHTMLAPI adds Span metrics to the HTML Handler function
+// MakeHTTPAPI adds Span metrics to the HTML Handler function
// This is used to serve HTML alongside JSON error messages
-func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request)) http.Handler {
+func MakeHTTPAPI(metricsName string, userAPI userapi.QueryAcccessTokenAPI, enableMetrics bool, f func(http.ResponseWriter, *http.Request), checks ...AuthAPIOption) http.Handler {
withSpan := func(w http.ResponseWriter, req *http.Request) {
trace, ctx := internal.StartTask(req.Context(), metricsName)
defer trace.EndTask()
req = req.WithContext(ctx)
+
+ // apply additional checks, if any
+ opts := AuthAPIOpts{}
+ for _, opt := range checks {
+ opt(&opts)
+ }
+
+ if opts.WithAuth {
+ logger := util.GetLogger(req.Context())
+ _, jsonErr := auth.VerifyUserFromRequest(req, userAPI)
+ if jsonErr != nil {
+ w.WriteHeader(jsonErr.Code)
+ if err := json.NewEncoder(w).Encode(jsonErr.JSON); err != nil {
+ logger.WithError(err).Error("failed to encode JSON response")
+ }
+ return
+ }
+ }
+
f(w, req)
}
diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go
index c4075789..93b84aa2 100644
--- a/internal/sqlutil/sqlutil_test.go
+++ b/internal/sqlutil/sqlutil_test.go
@@ -218,5 +218,5 @@ func assertNoError(t *testing.T, err error, msg string) {
if err == nil {
return
}
- t.Fatalf(msg)
+ t.Fatal(msg)
}
diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go
index 3425fbce..8b843e90 100644
--- a/mediaapi/mediaapi.go
+++ b/mediaapi/mediaapi.go
@@ -15,23 +15,26 @@
package mediaapi
import (
- "github.com/gorilla/mux"
+ "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/routing"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/sirupsen/logrus"
)
// AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component.
func AddPublicRoutes(
- mediaRouter *mux.Router,
+ routers httputil.Routers,
cm *sqlutil.Connections,
cfg *config.Dendrite,
userAPI userapi.MediaUserAPI,
client *fclient.Client,
+ fedClient fclient.FederationClient,
+ keyRing gomatrixserverlib.JSONVerifier,
) {
mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database)
if err != nil {
@@ -39,6 +42,6 @@ func AddPublicRoutes(
}
routing.Setup(
- mediaRouter, cfg, mediaDB, userAPI, client,
+ routers, cfg, mediaDB, userAPI, client, fedClient, keyRing,
)
}
diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go
index fa1c417a..c812b9d6 100644
--- a/mediaapi/routing/download.go
+++ b/mediaapi/routing/download.go
@@ -21,7 +21,9 @@ import (
"io"
"io/fs"
"mime"
+ "mime/multipart"
"net/http"
+ "net/textproto"
"net/url"
"os"
"path/filepath"
@@ -31,6 +33,7 @@ import (
"sync"
"unicode"
+ "github.com/google/uuid"
"github.com/matrix-org/dendrite/mediaapi/fileutils"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/thumbnailer"
@@ -61,6 +64,9 @@ type downloadRequest struct {
ThumbnailSize types.ThumbnailSize
Logger *log.Entry
DownloadFilename string
+ multipartResponse bool // whether we need to return a multipart/mixed response (for requests coming in over federation)
+ fedClient fclient.FederationClient
+ origin spec.ServerName
}
// Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84
@@ -111,11 +117,17 @@ func Download(
cfg *config.MediaAPI,
db storage.Database,
client *fclient.Client,
+ fedClient fclient.FederationClient,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
isThumbnailRequest bool,
customFilename string,
+ federationRequest bool,
) {
+ // This happens if we call Download for a federation request
+ if federationRequest && origin == "" {
+ origin = cfg.Matrix.ServerName
+ }
dReq := &downloadRequest{
MediaMetadata: &types.MediaMetadata{
MediaID: mediaID,
@@ -126,7 +138,10 @@ func Download(
"Origin": origin,
"MediaID": mediaID,
}),
- DownloadFilename: customFilename,
+ DownloadFilename: customFilename,
+ multipartResponse: federationRequest,
+ origin: cfg.Matrix.ServerName,
+ fedClient: fedClient,
}
if dReq.IsThumbnailRequest {
@@ -355,7 +370,7 @@ func (r *downloadRequest) respondFromLocalFile(
}).Trace("Responding with file")
responseFile = file
responseMetadata = r.MediaMetadata
- if err := r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil {
+ if err = r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil {
return nil, err
}
}
@@ -367,14 +382,61 @@ func (r *downloadRequest) respondFromLocalFile(
" plugin-types application/pdf;" +
" style-src 'unsafe-inline';" +
" object-src 'self';"
- w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
- if _, err := io.Copy(w, responseFile); err != nil {
- return nil, fmt.Errorf("io.Copy: %w", err)
+ if !r.multipartResponse {
+ w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
+ if _, err = io.Copy(w, responseFile); err != nil {
+ return nil, fmt.Errorf("io.Copy: %w", err)
+ }
+ } else {
+ var written int64
+ written, err = multipartResponse(w, r, string(responseMetadata.ContentType), responseFile)
+ if err != nil {
+ return nil, err
+ }
+ responseMetadata.FileSizeBytes = types.FileSizeBytes(written)
}
return responseMetadata, nil
}
+func multipartResponse(w http.ResponseWriter, r *downloadRequest, contentType string, responseFile io.Reader) (int64, error) {
+ // Update the header to be multipart/mixed; boundary=$randomBoundary
+ boundary := uuid.NewString()
+ w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary)
+
+ w.Header().Del("Content-Length") // let Go handle the content length
+ mw := multipart.NewWriter(w)
+ defer func() {
+ if err := mw.Close(); err != nil {
+ r.Logger.WithError(err).Error("Failed to close multipart writer")
+ }
+ }()
+
+ if err := mw.SetBoundary(boundary); err != nil {
+ return 0, fmt.Errorf("failed to set multipart boundary: %w", err)
+ }
+
+ // JSON object part
+ jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {"application/json"},
+ })
+ if err != nil {
+ return 0, fmt.Errorf("failed to create json writer: %w", err)
+ }
+ if _, err = jsonWriter.Write([]byte("{}")); err != nil {
+ return 0, fmt.Errorf("failed to write to json writer: %w", err)
+ }
+
+ // media part
+ mediaWriter, err := mw.CreatePart(textproto.MIMEHeader{
+ "Content-Type": {contentType},
+ })
+ if err != nil {
+ return 0, fmt.Errorf("failed to create media writer: %w", err)
+ }
+ return io.Copy(mediaWriter, responseFile)
+}
+
func (r *downloadRequest) addDownloadFilenameToHeaders(
w http.ResponseWriter,
responseMetadata *types.MediaMetadata,
@@ -722,8 +784,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
return nil
}
-func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
- reader := *body
+func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, reader io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) {
var contentLength int64
if contentLengthHeader != "" {
@@ -742,7 +803,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// We successfully parsed the Content-Length, so we'll return a limited
// reader that restricts us to reading only up to this size.
- reader = io.NopCloser(io.LimitReader(*body, parsedLength))
+ reader = io.NopCloser(io.LimitReader(reader, parsedLength))
contentLength = parsedLength
} else {
// Content-Length header is missing. If we have a maximum file size
@@ -751,7 +812,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
// ultimately it will get rewritten later when the temp file is written
// to disk.
if maxFileSizeBytes > 0 {
- reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes)))
+ reader = io.NopCloser(io.LimitReader(reader, int64(maxFileSizeBytes)))
}
contentLength = 0
}
@@ -759,6 +820,11 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string,
return contentLength, reader, nil
}
+// mediaMeta contains information about a multipart media response.
+// TODO: extend once something is defined.
+type mediaMeta struct{}
+
+// nolint: gocyclo
func (r *downloadRequest) fetchRemoteFile(
ctx context.Context,
client *fclient.Client,
@@ -767,19 +833,38 @@ func (r *downloadRequest) fetchRemoteFile(
) (types.Path, bool, error) {
r.Logger.Debug("Fetching remote file")
- // create request for remote file
- resp, err := client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
+ // Attempt to download via authenticated media endpoint
+ isAuthed := true
+ resp, err := r.fedClient.DownloadMedia(ctx, r.origin, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
- if resp != nil && resp.StatusCode == http.StatusNotFound {
- return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
+ isAuthed = false
+ // try again on the unauthed endpoint
+ // create request for remote file
+ resp, err = client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
+ if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) {
+ if resp != nil && resp.StatusCode == http.StatusNotFound {
+ return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
+ }
+ return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s: %w", r.MediaMetadata.MediaID, r.MediaMetadata.Origin, err)
}
- return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
}
defer resp.Body.Close() // nolint: errcheck
- // The reader returned here will be limited either by the Content-Length
- // and/or the configured maximum media size.
- contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes)
+ // If this wasn't a multipart response, set the Content-Type now. Will be overwritten
+ // by the multipart Content-Type below.
+ r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type"))
+
+ var contentLength int64
+ var reader io.Reader
+ var parseErr error
+ if isAuthed {
+ parseErr, contentLength, reader = parseMultipartResponse(r, resp, maxFileSizeBytes)
+ } else {
+ // The reader returned here will be limited either by the Content-Length
+ // and/or the configured maximum media size.
+ contentLength, reader, parseErr = r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes)
+ }
+
if parseErr != nil {
return "", false, parseErr
}
@@ -790,7 +875,6 @@ func (r *downloadRequest) fetchRemoteFile(
}
r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength)
- r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type"))
dispositionHeader := resp.Header.Get("Content-Disposition")
if _, params, e := mime.ParseMediaType(dispositionHeader); e == nil {
@@ -844,6 +928,50 @@ func (r *downloadRequest) fetchRemoteFile(
return types.Path(finalPath), duplicate, nil
}
+func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (error, int64, io.Reader) {
+ _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
+ if err != nil {
+ return err, 0, nil
+ }
+ if params["boundary"] == "" {
+ return fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin), 0, nil
+ }
+ mr := multipart.NewReader(resp.Body, params["boundary"])
+
+ // Get the first, JSON, part
+ p, err := mr.NextPart()
+ if err != nil {
+ return err, 0, nil
+ }
+ defer p.Close() // nolint: errcheck
+
+ if p.Header.Get("Content-Type") != "application/json" {
+ return fmt.Errorf("first part of the response must be application/json"), 0, nil
+ }
+ // Try to parse media meta information
+ meta := mediaMeta{}
+ if err = json.NewDecoder(p).Decode(&meta); err != nil {
+ return err, 0, nil
+ }
+ defer p.Close() // nolint: errcheck
+
+ // Get the actual media content
+ p, err = mr.NextPart()
+ if err != nil {
+ return err, 0, nil
+ }
+
+ redirect := p.Header.Get("Location")
+ if redirect != "" {
+ return fmt.Errorf("Location header is not yet supported"), 0, nil
+ }
+
+ contentLength, reader, err := r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes)
+ // For multipart requests, we need to get the Content-Type of the second part, which is the actual media
+ r.MediaMetadata.ContentType = types.ContentType(p.Header.Get("Content-Type"))
+ return err, contentLength, reader
+}
+
// contentDispositionFor returns the Content-Disposition for a given
// content type.
func contentDispositionFor(contentType types.ContentType) string {
diff --git a/mediaapi/routing/download_test.go b/mediaapi/routing/download_test.go
index 21f6bfc2..11368919 100644
--- a/mediaapi/routing/download_test.go
+++ b/mediaapi/routing/download_test.go
@@ -1,8 +1,13 @@
package routing
import (
+ "bytes"
+ "io"
+ "net/http"
+ "net/http/httptest"
"testing"
+ "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/stretchr/testify/assert"
)
@@ -11,3 +16,28 @@ func Test_dispositionFor(t *testing.T) {
assert.Equal(t, "attachment", contentDispositionFor("image/svg"), "image/svg")
assert.Equal(t, "inline", contentDispositionFor("image/jpeg"), "image/jpg")
}
+
+func Test_Multipart(t *testing.T) {
+ r := &downloadRequest{
+ MediaMetadata: &types.MediaMetadata{},
+ }
+ data := bytes.Buffer{}
+ responseBody := "This media is plain text. Maybe somebody used it as a paste bin."
+ data.WriteString(responseBody)
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ _, err := multipartResponse(w, r, "text/plain", &data)
+ assert.NoError(t, err)
+ }))
+ defer srv.Close()
+
+ resp, err := srv.Client().Get(srv.URL)
+ assert.NoError(t, err)
+ defer resp.Body.Close()
+ // contentLength is always 0, since there's no Content-Length header on the multipart part.
+ err, _, reader := parseMultipartResponse(r, resp, 1000)
+ assert.NoError(t, err)
+ gotResponse, err := io.ReadAll(reader)
+ assert.NoError(t, err)
+ assert.Equal(t, responseBody, string(gotResponse))
+}
diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go
index 5963eeaa..2867df60 100644
--- a/mediaapi/routing/routing.go
+++ b/mediaapi/routing/routing.go
@@ -20,11 +20,13 @@ import (
"strings"
"github.com/gorilla/mux"
+ "github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
@@ -45,15 +47,19 @@ type configResponse struct {
// applied:
// nolint: gocyclo
func Setup(
- publicAPIMux *mux.Router,
+ routers httputil.Routers,
cfg *config.Dendrite,
db storage.Database,
userAPI userapi.MediaUserAPI,
client *fclient.Client,
+ federationClient fclient.FederationClient,
+ keyRing gomatrixserverlib.JSONVerifier,
) {
rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting)
- v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter()
+ v3mux := routers.Media.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter()
+ v1mux := routers.Client.PathPrefix("/v1/media/").Subrouter()
+ v1fedMux := routers.Federation.PathPrefix("/v1/media/").Subrouter()
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
PathToResult: map[string]*types.ThumbnailGenerationResult{},
@@ -90,33 +96,103 @@ func Setup(
MXCToResult: map[string]*types.RemoteRequestResult{},
}
- downloadHandler := makeDownloadAPI("download", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration)
+ downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false)
v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/thumbnail/{serverName}/{mediaId}",
- makeDownloadAPI("thumbnail", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration),
+ makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false),
).Methods(http.MethodGet, http.MethodOptions)
+
+ // v1 client endpoints requiring auth
+ downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth())
+ v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions)
+ v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
+ v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions)
+
+ v1mux.Handle("/thumbnail/{serverName}/{mediaId}",
+ httputil.MakeHTTPAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()),
+ ).Methods(http.MethodGet, http.MethodOptions)
+
+ // same, but for federation
+ v1fedMux.Handle("/download/{mediaId}", routing.MakeFedHTTPAPI(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
+ makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
+ )).Methods(http.MethodGet, http.MethodOptions)
+ v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedHTTPAPI(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing,
+ makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true),
+ )).Methods(http.MethodGet, http.MethodOptions)
}
+var thumbnailCounter = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Namespace: "dendrite",
+ Subsystem: "mediaapi",
+ Name: "thumbnail",
+ Help: "Total number of media_api requests for thumbnails",
+ },
+ []string{"code", "type"},
+)
+
+var thumbnailSize = promauto.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Namespace: "dendrite",
+ Subsystem: "mediaapi",
+ Name: "thumbnail_size_bytes",
+ Help: "Total size of media_api requests for thumbnails",
+ Buckets: []float64{50, 100, 200, 500, 900, 1500, 3000, 6000},
+ },
+ []string{"code", "type"},
+)
+
+var downloadCounter = promauto.NewCounterVec(
+ prometheus.CounterOpts{
+ Namespace: "dendrite",
+ Subsystem: "mediaapi",
+ Name: "download",
+ Help: "Total size of media_api requests for full downloads",
+ },
+ []string{"code", "type"},
+)
+
+var downloadSize = promauto.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Namespace: "dendrite",
+ Subsystem: "mediaapi",
+ Name: "download_size_bytes",
+ Help: "Total size of media_api requests for full downloads",
+ Buckets: []float64{1500, 3000, 6000, 10_000, 50_000, 100_000},
+ },
+ []string{"code", "type"},
+)
+
func makeDownloadAPI(
name string,
cfg *config.MediaAPI,
rateLimits *httputil.RateLimits,
db storage.Database,
client *fclient.Client,
+ fedClient fclient.FederationClient,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
+ forFederation bool,
) http.HandlerFunc {
var counterVec *prometheus.CounterVec
+ var sizeVec *prometheus.HistogramVec
+ var requestType string
if cfg.Matrix.Metrics.Enabled {
- counterVec = promauto.NewCounterVec(
- prometheus.CounterOpts{
- Name: name,
- Help: "Total number of media_api requests for either thumbnails or full downloads",
- },
- []string{"code"},
- )
+ split := strings.Split(name, "_")
+ // The first part of the split is either "download" or "thumbnail"
+ name = split[0]
+ // The remainder of the split is something like "authed_download" or "unauthed_thumbnail", etc.
+ // This is used to curry the metrics with the given types.
+ requestType = strings.Join(split[1:], "_")
+
+ counterVec = thumbnailCounter
+ sizeVec = thumbnailSize
+ if name != "thumbnail" {
+ counterVec = downloadCounter
+ sizeVec = downloadSize
+ }
}
httpHandler := func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req)
@@ -164,16 +240,21 @@ func makeDownloadAPI(
cfg,
db,
client,
+ fedClient,
activeRemoteRequests,
activeThumbnailGeneration,
- name == "thumbnail",
+ strings.HasPrefix(name, "thumbnail"),
vars["downloadName"],
+ forFederation,
)
}
var handlerFunc http.HandlerFunc
if counterVec != nil {
+ counterVec = counterVec.MustCurryWith(prometheus.Labels{"type": requestType})
+ sizeVec2 := sizeVec.MustCurryWith(prometheus.Labels{"type": requestType})
handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler))
+ handlerFunc = promhttp.InstrumentHandlerResponseSize(sizeVec2, handlerFunc).ServeHTTP
} else {
handlerFunc = http.HandlerFunc(httpHandler)
}
diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go
index 09920308..7fd20f11 100644
--- a/roomserver/acls/acls_test.go
+++ b/roomserver/acls/acls_test.go
@@ -29,11 +29,11 @@ func TestOpenACLsWithBlacklist(t *testing.T) {
roomID := "!test:test.com"
allowRegex, err := compileACLRegex("*")
if err != nil {
- t.Fatalf(err.Error())
+ t.Fatal(err)
}
denyRegex, err := compileACLRegex("foo.com")
if err != nil {
- t.Fatalf(err.Error())
+ t.Fatal(err)
}
acls := ServerACLs{
@@ -72,7 +72,7 @@ func TestDefaultACLsWithWhitelist(t *testing.T) {
roomID := "!test:test.com"
allowRegex, err := compileACLRegex("foo.com")
if err != nil {
- t.Fatalf(err.Error())
+ t.Fatal(err)
}
acls := ServerACLs{
diff --git a/setup/monolith.go b/setup/monolith.go
index 4856d6e8..72750354 100644
--- a/setup/monolith.go
+++ b/setup/monolith.go
@@ -78,7 +78,7 @@ func (m *Monolith) AddAllPublicRoutes(
federationapi.AddPublicRoutes(
processCtx, routers, cfg, natsInstance, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, enableMetrics,
)
- mediaapi.AddPublicRoutes(routers.Media, cm, cfg, m.UserAPI, m.Client)
+ mediaapi.AddPublicRoutes(routers, cm, cfg, m.UserAPI, m.Client, m.FedClient, m.KeyRing)
syncapi.AddPublicRoutes(processCtx, routers, cfg, cm, natsInstance, m.UserAPI, m.RoomserverAPI, caches, enableMetrics)
if m.RelayAPI != nil {
diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go
index 422898c7..81127481 100644
--- a/userapi/internal/key_api.go
+++ b/userapi/internal/key_api.go
@@ -196,7 +196,7 @@ func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Quer
if m.StreamID > maxStreamID {
maxStreamID = m.StreamID
}
- if m.KeyJSON == nil || len(m.KeyJSON) == 0 {
+ if len(m.KeyJSON) == 0 {
continue
}
result = append(result, m)