diff options
Diffstat (limited to 'mediaapi/routing/download.go')
-rw-r--r-- | mediaapi/routing/download.go | 164 |
1 files changed, 146 insertions, 18 deletions
diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index fa1c417a..c812b9d6 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -21,7 +21,9 @@ import ( "io" "io/fs" "mime" + "mime/multipart" "net/http" + "net/textproto" "net/url" "os" "path/filepath" @@ -31,6 +33,7 @@ import ( "sync" "unicode" + "github.com/google/uuid" "github.com/matrix-org/dendrite/mediaapi/fileutils" "github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/thumbnailer" @@ -61,6 +64,9 @@ type downloadRequest struct { ThumbnailSize types.ThumbnailSize Logger *log.Entry DownloadFilename string + multipartResponse bool // whether we need to return a multipart/mixed response (for requests coming in over federation) + fedClient fclient.FederationClient + origin spec.ServerName } // Taken from: https://github.com/matrix-org/synapse/blob/c3627d0f99ed5a23479305dc2bd0e71ca25ce2b1/synapse/media/_base.py#L53C1-L84 @@ -111,11 +117,17 @@ func Download( cfg *config.MediaAPI, db storage.Database, client *fclient.Client, + fedClient fclient.FederationClient, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, isThumbnailRequest bool, customFilename string, + federationRequest bool, ) { + // This happens if we call Download for a federation request + if federationRequest && origin == "" { + origin = cfg.Matrix.ServerName + } dReq := &downloadRequest{ MediaMetadata: &types.MediaMetadata{ MediaID: mediaID, @@ -126,7 +138,10 @@ func Download( "Origin": origin, "MediaID": mediaID, }), - DownloadFilename: customFilename, + DownloadFilename: customFilename, + multipartResponse: federationRequest, + origin: cfg.Matrix.ServerName, + fedClient: fedClient, } if dReq.IsThumbnailRequest { @@ -355,7 +370,7 @@ func (r *downloadRequest) respondFromLocalFile( }).Trace("Responding with file") responseFile = file responseMetadata = r.MediaMetadata - if err := r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil { + if err = r.addDownloadFilenameToHeaders(w, responseMetadata); err != nil { return nil, err } } @@ -367,14 +382,61 @@ func (r *downloadRequest) respondFromLocalFile( " plugin-types application/pdf;" + " style-src 'unsafe-inline';" + " object-src 'self';" - w.Header().Set("Content-Security-Policy", contentSecurityPolicy) - if _, err := io.Copy(w, responseFile); err != nil { - return nil, fmt.Errorf("io.Copy: %w", err) + if !r.multipartResponse { + w.Header().Set("Content-Security-Policy", contentSecurityPolicy) + if _, err = io.Copy(w, responseFile); err != nil { + return nil, fmt.Errorf("io.Copy: %w", err) + } + } else { + var written int64 + written, err = multipartResponse(w, r, string(responseMetadata.ContentType), responseFile) + if err != nil { + return nil, err + } + responseMetadata.FileSizeBytes = types.FileSizeBytes(written) } return responseMetadata, nil } +func multipartResponse(w http.ResponseWriter, r *downloadRequest, contentType string, responseFile io.Reader) (int64, error) { + // Update the header to be multipart/mixed; boundary=$randomBoundary + boundary := uuid.NewString() + w.Header().Set("Content-Type", "multipart/mixed; boundary="+boundary) + + w.Header().Del("Content-Length") // let Go handle the content length + mw := multipart.NewWriter(w) + defer func() { + if err := mw.Close(); err != nil { + r.Logger.WithError(err).Error("Failed to close multipart writer") + } + }() + + if err := mw.SetBoundary(boundary); err != nil { + return 0, fmt.Errorf("failed to set multipart boundary: %w", err) + } + + // JSON object part + jsonWriter, err := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"application/json"}, + }) + if err != nil { + return 0, fmt.Errorf("failed to create json writer: %w", err) + } + if _, err = jsonWriter.Write([]byte("{}")); err != nil { + return 0, fmt.Errorf("failed to write to json writer: %w", err) + } + + // media part + mediaWriter, err := mw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {contentType}, + }) + if err != nil { + return 0, fmt.Errorf("failed to create media writer: %w", err) + } + return io.Copy(mediaWriter, responseFile) +} + func (r *downloadRequest) addDownloadFilenameToHeaders( w http.ResponseWriter, responseMetadata *types.MediaMetadata, @@ -722,8 +784,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( return nil } -func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, body *io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) { - reader := *body +func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, reader io.ReadCloser, maxFileSizeBytes config.FileSizeBytes) (int64, io.Reader, error) { var contentLength int64 if contentLengthHeader != "" { @@ -742,7 +803,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, // We successfully parsed the Content-Length, so we'll return a limited // reader that restricts us to reading only up to this size. - reader = io.NopCloser(io.LimitReader(*body, parsedLength)) + reader = io.NopCloser(io.LimitReader(reader, parsedLength)) contentLength = parsedLength } else { // Content-Length header is missing. If we have a maximum file size @@ -751,7 +812,7 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, // ultimately it will get rewritten later when the temp file is written // to disk. if maxFileSizeBytes > 0 { - reader = io.NopCloser(io.LimitReader(*body, int64(maxFileSizeBytes))) + reader = io.NopCloser(io.LimitReader(reader, int64(maxFileSizeBytes))) } contentLength = 0 } @@ -759,6 +820,11 @@ func (r *downloadRequest) GetContentLengthAndReader(contentLengthHeader string, return contentLength, reader, nil } +// mediaMeta contains information about a multipart media response. +// TODO: extend once something is defined. +type mediaMeta struct{} + +// nolint: gocyclo func (r *downloadRequest) fetchRemoteFile( ctx context.Context, client *fclient.Client, @@ -767,19 +833,38 @@ func (r *downloadRequest) fetchRemoteFile( ) (types.Path, bool, error) { r.Logger.Debug("Fetching remote file") - // create request for remote file - resp, err := client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) + // Attempt to download via authenticated media endpoint + isAuthed := true + resp, err := r.fedClient.DownloadMedia(ctx, r.origin, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) { - if resp != nil && resp.StatusCode == http.StatusNotFound { - return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) + isAuthed = false + // try again on the unauthed endpoint + // create request for remote file + resp, err = client.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) + if err != nil || (resp != nil && resp.StatusCode != http.StatusOK) { + if resp != nil && resp.StatusCode == http.StatusNotFound { + return "", false, fmt.Errorf("File with media ID %q does not exist on %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) + } + return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s: %w", r.MediaMetadata.MediaID, r.MediaMetadata.Origin, err) } - return "", false, fmt.Errorf("file with media ID %q could not be downloaded from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) } defer resp.Body.Close() // nolint: errcheck - // The reader returned here will be limited either by the Content-Length - // and/or the configured maximum media size. - contentLength, reader, parseErr := r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), &resp.Body, maxFileSizeBytes) + // If this wasn't a multipart response, set the Content-Type now. Will be overwritten + // by the multipart Content-Type below. + r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) + + var contentLength int64 + var reader io.Reader + var parseErr error + if isAuthed { + parseErr, contentLength, reader = parseMultipartResponse(r, resp, maxFileSizeBytes) + } else { + // The reader returned here will be limited either by the Content-Length + // and/or the configured maximum media size. + contentLength, reader, parseErr = r.GetContentLengthAndReader(resp.Header.Get("Content-Length"), resp.Body, maxFileSizeBytes) + } + if parseErr != nil { return "", false, parseErr } @@ -790,7 +875,6 @@ func (r *downloadRequest) fetchRemoteFile( } r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength) - r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) dispositionHeader := resp.Header.Get("Content-Disposition") if _, params, e := mime.ParseMediaType(dispositionHeader); e == nil { @@ -844,6 +928,50 @@ func (r *downloadRequest) fetchRemoteFile( return types.Path(finalPath), duplicate, nil } +func parseMultipartResponse(r *downloadRequest, resp *http.Response, maxFileSizeBytes config.FileSizeBytes) (error, int64, io.Reader) { + _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return err, 0, nil + } + if params["boundary"] == "" { + return fmt.Errorf("no boundary header found on media %s from %s", r.MediaMetadata.MediaID, r.MediaMetadata.Origin), 0, nil + } + mr := multipart.NewReader(resp.Body, params["boundary"]) + + // Get the first, JSON, part + p, err := mr.NextPart() + if err != nil { + return err, 0, nil + } + defer p.Close() // nolint: errcheck + + if p.Header.Get("Content-Type") != "application/json" { + return fmt.Errorf("first part of the response must be application/json"), 0, nil + } + // Try to parse media meta information + meta := mediaMeta{} + if err = json.NewDecoder(p).Decode(&meta); err != nil { + return err, 0, nil + } + defer p.Close() // nolint: errcheck + + // Get the actual media content + p, err = mr.NextPart() + if err != nil { + return err, 0, nil + } + + redirect := p.Header.Get("Location") + if redirect != "" { + return fmt.Errorf("Location header is not yet supported"), 0, nil + } + + contentLength, reader, err := r.GetContentLengthAndReader(p.Header.Get("Content-Length"), p, maxFileSizeBytes) + // For multipart requests, we need to get the Content-Type of the second part, which is the actual media + r.MediaMetadata.ContentType = types.ContentType(p.Header.Get("Content-Type")) + return err, contentLength, reader +} + // contentDispositionFor returns the Content-Disposition for a given // content type. func contentDispositionFor(contentType types.ContentType) string { |