diff options
author | Till <2353100+S7evinK@users.noreply.github.com> | 2024-08-16 12:37:59 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-16 12:37:59 +0200 |
commit | 7a4ef240fc8ec97ba957933de3a80e611ad7d1f5 (patch) | |
tree | c8946995640907a3ea6e64a8a0509a23b696c69e /mediaapi | |
parent | 8c6cf51b8f6dd0f34ecc0f0b38d5475e2055a297 (diff) |
Implement MSC3916 (#3397)
Needs https://github.com/matrix-org/gomatrixserverlib/pull/437
Diffstat (limited to 'mediaapi')
-rw-r--r-- | mediaapi/mediaapi.go | 9 | ||||
-rw-r--r-- | mediaapi/routing/download.go | 164 | ||||
-rw-r--r-- | mediaapi/routing/download_test.go | 30 | ||||
-rw-r--r-- | mediaapi/routing/routing.go | 105 |
4 files changed, 275 insertions, 33 deletions
diff --git a/mediaapi/mediaapi.go b/mediaapi/mediaapi.go index 3425fbce..8b843e90 100644 --- a/mediaapi/mediaapi.go +++ b/mediaapi/mediaapi.go @@ -15,23 +15,26 @@ package mediaapi import ( - "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/routing" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/sirupsen/logrus" ) // AddPublicRoutes sets up and registers HTTP handlers for the MediaAPI component. func AddPublicRoutes( - mediaRouter *mux.Router, + routers httputil.Routers, cm *sqlutil.Connections, cfg *config.Dendrite, userAPI userapi.MediaUserAPI, client *fclient.Client, + fedClient fclient.FederationClient, + keyRing gomatrixserverlib.JSONVerifier, ) { mediaDB, err := storage.NewMediaAPIDatasource(cm, &cfg.MediaAPI.Database) if err != nil { @@ -39,6 +42,6 @@ func AddPublicRoutes( } routing.Setup( - mediaRouter, cfg, mediaDB, userAPI, client, + routers, cfg, mediaDB, userAPI, client, fedClient, keyRing, ) } diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index fa1c417a..c812b9d6 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -21,7 +21,9 @@ import ( "io" "io/fs" "mime" + "mime/multipart" "net/http" + "net/textproto" "net/url" "os" "path/filepath" @@ -31,6 +33,7 @@ import ( "sync" "unicode" + "github.com/google/uuid" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" @@ -61,6 +64,9 @@ type downloadRequest struct { ThumbnailSize types.ThumbnailSize Logger *log.Entry DownloadFilename string + multipartResponse bool // whether we need to return a multipart/mixed response (for requests coming in over federation) + fedClient fclient.FederationClient + origin spec.ServerName } // Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84 @@ -111,11 +117,17 @@ func Download( cfg *config.MediaAPI, db storage.Database, client *fclient.Client, + fedClient fclient.FederationClient, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, isThumbnailRequest bool, customFilename string, + federationRequest bool, ) { + // This happens if we call Download for a federation request + if federationRequest && origin == "" { + origin = cfg.Matrix.ServerName + } dReq := &downloadRequest{ MediaMetadata: &types.MediaMetadata{ MediaID: mediaID, @@ -126,7 +138,10 @@ func Download( "Origin": origin, "MediaID": mediaID, }), - DownloadFilename: customFilename, + DownloadFilename: customFilename, + multipartResponse: federationRequest, + origin: cfg.Matrix.ServerName, + fedClient: fedClient, } if dReq.IsThumbnailRequest { @@ -355,7 +370,7 @@ func (r *downloadRequest) respondFromLocalFile( }).Trace("Responding with file") responseFile = file responseMetadata = r.MediaMetadata - if err := r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil { + if err = r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil { return nil, err } } @@ -367,14 +382,61 @@ func (r *downloadRequest) respondFromLocalFile( " plugin-types application/pdf;" + " style-src 'unsafe-inline';" + " object-src 'self';" - w.Header().Set("Content-Security-Policy", contentSecurityPolicy) - if _, err := io.Copy(w, responseFile); err != nil { - return nil, fmt.Errorf("io.Copy: %w", err) + if !r.multipartResponse { + w.Header().Set("Content-Security-Policy", contentSecurityPolicy) + if _, err = io.Copy(w, responseFile); err != nil { + return nil, fmt.Errorf("io.Copy: %w", err) + } + } else { + var written int64 + written, err = multipartResponse(w, r, string(responseMetadata.ContentType), responseFile) + if err != nil { + return nil, err + } + responseMetadata.FileSizeBytes = types.FileSizeBytes(written) } return responseMetadata, nil } +func multipartResponse(w http.ResponseWriter, r *downloadRequest, contentType string, responseFile io.Reader) (int64, error) { + // Update the header to be multipart/mixed; boundary=$randomBoundary + boundary := uuid.NewString() + w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary) + + w.Header().Del("Content-Length") // let Go handle the content length + mw := multipart.NewWriter(w) + defer func() { + if err := mw.Close(); err != nil { + r.Logger.WithError(err).Error("Failed to close multipart writer") + } + }() + + if err := mw.SetBoundary(boundary); err != nil { + return 0, fmt.Errorf("failed to set multipart boundary: %w", err) + } + + // JSON object part + jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"application/json"}, + }) + if err != nil { + return 0, fmt.Errorf("failed to create json writer: %w", err) + } + if _, err = jsonWriter.Write([]byte("{}")); err != nil { + return 0, fmt.Errorf("failed to write to json writer: %w", err) + } + + // media part + mediaWriter, err := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {contentType}, + }) + if err != nil { + return 0, fmt.Errorf("failed to create media writer: %w", err) + } + return io.Copy(mediaWriter, responseFile) +} + func (r *downloadRequest) addDownloadFilenameToHeaders( w http.ResponseWriter, responseMetadata *types.MediaMetadata, @@ -722,8 +784,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( return nil } -func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) { - reader := *body +func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, reader io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) { var contentLength int64 if contentLengthHeader != "" { @@ -742,7 +803,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, // We successfully parsed the Content-Length, so we'll return a limited // reader that restricts us to reading only up to this size. - reader = io.NopCloser(io.LimitReader(*body, parsedLength)) + reader = io.NopCloser(io.LimitReader(reader, parsedLength)) contentLength = parsedLength } else { // Content-Length header is missing. If we have a maximum file size @@ -751,7 +812,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, // ultimately it will get rewritten later when the temp file is written // to disk. if maxFileSizeBytes > 0 { - reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes))) + reader = io.NopCloser(io.LimitReader(reader, int64(maxFileSizeBytes))) } contentLength = 0 } @@ -759,6 +820,11 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, return contentLength, reader, nil } +// mediaMeta contains information about a multipart media response. +// TODO: extend once something is defined. +type mediaMeta struct{} + +// nolint: gocyclo func (r *downloadRequest) fetchRemoteFile( ctx context.Context, client *fclient.Client, @@ -767,19 +833,38 @@ func (r *downloadRequest) fetchRemoteFile( ) (types.Path, bool, error) { r.Logger.Debug("Fetching remote file") - // create request for remote file - resp, err := client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) + // Attempt to download via authenticated media endpoint + isAuthed := true + resp, err := r.fedClient.DownloadMedia(ctx, r.origin, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) { - if resp != nil && resp.StatusCode == http.StatusNotFound { - return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) + isAuthed = false + // try again on the unauthed endpoint + // create request for remote file + resp, err = client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) + if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) { + if resp != nil && resp.StatusCode == http.StatusNotFound { + return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) + } + return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s: %w", r.MediaMetadata.MediaID, r.MediaMetadata.Origin, err) } - return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) } defer resp.Body.Close() // nolint: errcheck - // The reader returned here will be limited either by the Content-Length - // and/or the configured maximum media size. - contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes) + // If this wasn't a multipart response, set the Content-Type now. Will be overwritten + // by the multipart Content-Type below. + r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) + + var contentLength int64 + var reader io.Reader + var parseErr error + if isAuthed { + parseErr, contentLength, reader = parseMultipartResponse(r, resp, maxFileSizeBytes) + } else { + // The reader returned here will be limited either by the Content-Length + // and/or the configured maximum media size. + contentLength, reader, parseErr = r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes) + } + if parseErr != nil { return "", false, parseErr } @@ -790,7 +875,6 @@ func (r *downloadRequest) fetchRemoteFile( } r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength) - r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) dispositionHeader := resp.Header.Get("Content-Disposition") if _, params, e := mime.ParseMediaType(dispositionHeader); e == nil { @@ -844,6 +928,50 @@ func (r *downloadRequest) fetchRemoteFile( return types.Path(finalPath), duplicate, nil } +func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (error, int64, io.Reader) { + _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return err, 0, nil + } + if params["boundary"] == "" { + return fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin), 0, nil + } + mr := multipart.NewReader(resp.Body, params["boundary"]) + + // Get the first, JSON, part + p, err := mr.NextPart() + if err != nil { + return err, 0, nil + } + defer p.Close() // nolint: errcheck + + if p.Header.Get("Content-Type") != "application/json" { + return fmt.Errorf("first part of the response must be application/json"), 0, nil + } + // Try to parse media meta information + meta := mediaMeta{} + if err = json.NewDecoder(p).Decode(&meta); err != nil { + return err, 0, nil + } + defer p.Close() // nolint: errcheck + + // Get the actual media content + p, err = mr.NextPart() + if err != nil { + return err, 0, nil + } + + redirect := p.Header.Get("Location") + if redirect != "" { + return fmt.Errorf("Location header is not yet supported"), 0, nil + } + + contentLength, reader, err := r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes) + // For multipart requests, we need to get the Content-Type of the second part, which is the actual media + r.MediaMetadata.ContentType = types.ContentType(p.Header.Get("Content-Type")) + return err, contentLength, reader +} + // contentDispositionFor returns the Content-Disposition for a given // content type. func contentDispositionFor(contentType types.ContentType) string { diff --git a/mediaapi/routing/download_test.go b/mediaapi/routing/download_test.go index 21f6bfc2..11368919 100644 --- a/mediaapi/routing/download_test.go +++ b/mediaapi/routing/download_test.go @@ -1,8 +1,13 @@ package routing import ( + "bytes" + "io" + "net/http" + "net/http/httptest" "testing" + "github.com/matrix-org/dendrite/mediaapi/types" "github.com/stretchr/testify/assert" ) @@ -11,3 +16,28 @@ func Test_dispositionFor(t *testing.T) { assert.Equal(t, "attachment", contentDispositionFor("image/svg"), "image/svg") assert.Equal(t, "inline", contentDispositionFor("image/jpeg"), "image/jpg") } + +func Test_Multipart(t *testing.T) { + r := &downloadRequest{ + MediaMetadata: &types.MediaMetadata{}, + } + data := bytes.Buffer{} + responseBody := "This media is plain text. Maybe somebody used it as a paste bin." + data.WriteString(responseBody) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, err := multipartResponse(w, r, "text/plain", &data) + assert.NoError(t, err) + })) + defer srv.Close() + + resp, err := srv.Client().Get(srv.URL) + assert.NoError(t, err) + defer resp.Body.Close() + // contentLength is always 0, since there's no Content-Length header on the multipart part. + err, _, reader := parseMultipartResponse(r, resp, 1000) + assert.NoError(t, err) + gotResponse, err := io.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, responseBody, string(gotResponse)) +} diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 5963eeaa..2867df60 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -20,11 +20,13 @@ import ( "strings" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/federationapi/routing" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" @@ -45,15 +47,19 @@ type configResponse struct { // applied: // nolint: gocyclo func Setup( - publicAPIMux *mux.Router, + routers httputil.Routers, cfg *config.Dendrite, db storage.Database, userAPI userapi.MediaUserAPI, client *fclient.Client, + federationClient fclient.FederationClient, + keyRing gomatrixserverlib.JSONVerifier, ) { rateLimits := httputil.NewRateLimits(&cfg.ClientAPI.RateLimiting) - v3mux := publicAPIMux.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter() + v3mux := routers.Media.PathPrefix("/{apiversion:(?:r0|v1|v3)}/").Subrouter() + v1mux := routers.Client.PathPrefix("/v1/media/").Subrouter() + v1fedMux := routers.Federation.PathPrefix("/v1/media/").Subrouter() activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ PathToResult: map[string]*types.ThumbnailGenerationResult{}, @@ -90,33 +96,103 @@ func Setup( MXCToResult: map[string]*types.RemoteRequestResult{}, } - downloadHandler := makeDownloadAPI("download", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration) + downloadHandler := makeDownloadAPI("download_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false) v3mux.Handle("/download/{serverName}/{mediaId}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandler).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thumbnail/{serverName}/{mediaId}", - makeDownloadAPI("thumbnail", &cfg.MediaAPI, rateLimits, db, client, activeRemoteRequests, activeThumbnailGeneration), + makeDownloadAPI("thumbnail_unauthed", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), ).Methods(http.MethodGet, http.MethodOptions) + + // v1 client endpoints requiring auth + downloadHandlerAuthed := httputil.MakeHTTPAPI("download", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("download_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()) + v1mux.Handle("/config", configHandler).Methods(http.MethodGet, http.MethodOptions) + v1mux.Handle("/download/{serverName}/{mediaId}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions) + v1mux.Handle("/download/{serverName}/{mediaId}/{downloadName}", downloadHandlerAuthed).Methods(http.MethodGet, http.MethodOptions) + + v1mux.Handle("/thumbnail/{serverName}/{mediaId}", + httputil.MakeHTTPAPI("thumbnail", userAPI, cfg.Global.Metrics.Enabled, makeDownloadAPI("thumbnail_authed_client", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, false), httputil.WithAuth()), + ).Methods(http.MethodGet, http.MethodOptions) + + // same, but for federation + v1fedMux.Handle("/download/{mediaId}", routing.MakeFedHTTPAPI(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, + makeDownloadAPI("download_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true), + )).Methods(http.MethodGet, http.MethodOptions) + v1fedMux.Handle("/thumbnail/{mediaId}", routing.MakeFedHTTPAPI(cfg.Global.ServerName, cfg.Global.IsLocalServerName, keyRing, + makeDownloadAPI("thumbnail_authed_federation", &cfg.MediaAPI, rateLimits, db, client, federationClient, activeRemoteRequests, activeThumbnailGeneration, true), + )).Methods(http.MethodGet, http.MethodOptions) } +var thumbnailCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "mediaapi", + Name: "thumbnail", + Help: "Total number of media_api requests for thumbnails", + }, + []string{"code", "type"}, +) + +var thumbnailSize = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "dendrite", + Subsystem: "mediaapi", + Name: "thumbnail_size_bytes", + Help: "Total size of media_api requests for thumbnails", + Buckets: []float64{50, 100, 200, 500, 900, 1500, 3000, 6000}, + }, + []string{"code", "type"}, +) + +var downloadCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "mediaapi", + Name: "download", + Help: "Total size of media_api requests for full downloads", + }, + []string{"code", "type"}, +) + +var downloadSize = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "dendrite", + Subsystem: "mediaapi", + Name: "download_size_bytes", + Help: "Total size of media_api requests for full downloads", + Buckets: []float64{1500, 3000, 6000, 10_000, 50_000, 100_000}, + }, + []string{"code", "type"}, +) + func makeDownloadAPI( name string, cfg *config.MediaAPI, rateLimits *httputil.RateLimits, db storage.Database, client *fclient.Client, + fedClient fclient.FederationClient, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, + forFederation bool, ) http.HandlerFunc { var counterVec *prometheus.CounterVec + var sizeVec *prometheus.HistogramVec + var requestType string if cfg.Matrix.Metrics.Enabled { - counterVec = promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: name, - Help: "Total number of media_api requests for either thumbnails or full downloads", - }, - []string{"code"}, - ) + split := strings.Split(name, "_") + // The first part of the split is either "download" or "thumbnail" + name = split[0] + // The remainder of the split is something like "authed_download" or "unauthed_thumbnail", etc. + // This is used to curry the metrics with the given types. + requestType = strings.Join(split[1:], "_") + + counterVec = thumbnailCounter + sizeVec = thumbnailSize + if name != "thumbnail" { + counterVec = downloadCounter + sizeVec = downloadSize + } } httpHandler := func(w http.ResponseWriter, req *http.Request) { req = util.RequestWithLogging(req) @@ -164,16 +240,21 @@ func makeDownloadAPI( cfg, db, client, + fedClient, activeRemoteRequests, activeThumbnailGeneration, - name == "thumbnail", + strings.HasPrefix(name, "thumbnail"), vars["downloadName"], + forFederation, ) } var handlerFunc http.HandlerFunc if counterVec != nil { + counterVec = counterVec.MustCurryWith(prometheus.Labels{"type": requestType}) + sizeVec2 := sizeVec.MustCurryWith(prometheus.Labels{"type": requestType}) handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + handlerFunc = promhttp.InstrumentHandlerResponseSize(sizeVec2, handlerFunc).ServeHTTP } else { handlerFunc = http.HandlerFunc(httpHandler) } |